mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-17 03:04:20 +00:00
feat(agent-runner): add host run ledger primitives
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Agent runner modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .descriptor import AgentRunnerDescriptor
|
||||
@@ -24,6 +25,7 @@ from .session_registry import (
|
||||
RunAuthorizationSnapshot,
|
||||
get_session_registry,
|
||||
)
|
||||
from .run_ledger_store import RunLedgerStore
|
||||
from .events import (
|
||||
MESSAGE_RECEIVED,
|
||||
MESSAGE_RECALLED,
|
||||
@@ -55,6 +57,7 @@ __all__ = [
|
||||
'AgentRunSession',
|
||||
'RunAuthorizationSnapshot',
|
||||
'get_session_registry',
|
||||
'RunLedgerStore',
|
||||
'MESSAGE_RECEIVED',
|
||||
'MESSAGE_RECALLED',
|
||||
'GROUP_MEMBER_JOINED',
|
||||
|
||||
@@ -420,7 +420,21 @@ class AgentRunContextBuilder:
|
||||
event_page_enabled = 'page' in event_perms and conversation_id is not None
|
||||
artifact_metadata_enabled = 'metadata' in artifact_perms
|
||||
artifact_read_enabled = 'read' in artifact_perms
|
||||
steering_pull_enabled = bool(getattr(descriptor.capabilities, 'steering', False)) and conversation_id is not None
|
||||
steering_pull_enabled = (
|
||||
bool(getattr(descriptor.capabilities, 'steering', False)) and conversation_id is not None
|
||||
)
|
||||
run_get_enabled = True
|
||||
run_list_enabled = conversation_id is not None
|
||||
run_events_page_enabled = True
|
||||
run_cancel_enabled = True
|
||||
run_append_result_enabled = False
|
||||
run_finalize_enabled = False
|
||||
run_claim_enabled = False
|
||||
run_renew_claim_enabled = False
|
||||
run_release_claim_enabled = False
|
||||
runtime_register_enabled = False
|
||||
runtime_heartbeat_enabled = False
|
||||
runtime_list_enabled = False
|
||||
|
||||
# Determine state API availability based on binding state_policy.
|
||||
state_enabled = False
|
||||
@@ -431,9 +445,8 @@ class AgentRunContextBuilder:
|
||||
state_enabled = True
|
||||
|
||||
resource_policy = binding.resource_policy
|
||||
storage_enabled = (
|
||||
('plugin' in storage_perms and resource_policy.allow_plugin_storage)
|
||||
or ('workspace' in storage_perms and resource_policy.allow_workspace_storage)
|
||||
storage_enabled = ('plugin' in storage_perms and resource_policy.allow_plugin_storage) or (
|
||||
'workspace' in storage_perms and resource_policy.allow_workspace_storage
|
||||
)
|
||||
|
||||
# Get latest cursor and has_history_before if conversation exists
|
||||
@@ -477,5 +490,17 @@ class AgentRunContextBuilder:
|
||||
'state': state_enabled,
|
||||
'storage': storage_enabled,
|
||||
'steering_pull': steering_pull_enabled,
|
||||
'run_get': run_get_enabled,
|
||||
'run_list': run_list_enabled,
|
||||
'run_events_page': run_events_page_enabled,
|
||||
'run_cancel': run_cancel_enabled,
|
||||
'run_append_result': run_append_result_enabled,
|
||||
'run_finalize': run_finalize_enabled,
|
||||
'run_claim': run_claim_enabled,
|
||||
'run_renew_claim': run_renew_claim_enabled,
|
||||
'run_release_claim': run_release_claim_enabled,
|
||||
'runtime_register': runtime_register_enabled,
|
||||
'runtime_heartbeat': runtime_heartbeat_enabled,
|
||||
'runtime_list': runtime_list_enabled,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -103,13 +103,40 @@ class AgentRunOrchestrator:
|
||||
|
||||
state_context = build_state_context(event, binding, descriptor)
|
||||
run_id = context['run_id']
|
||||
available_apis = context.get('context', {}).get('available_apis')
|
||||
run_authorization = {
|
||||
'runner_id': descriptor.id,
|
||||
'binding_id': binding.binding_id,
|
||||
'plugin_identity': descriptor.get_plugin_id(),
|
||||
'resources': resources,
|
||||
'available_apis': available_apis,
|
||||
'conversation_id': event.conversation_id,
|
||||
'bot_id': event.bot_id,
|
||||
'workspace_id': event.workspace_id,
|
||||
'thread_id': event.thread_id,
|
||||
'state_policy': {
|
||||
'enable_state': binding.state_policy.enable_state,
|
||||
'state_scopes': list(binding.state_policy.state_scopes),
|
||||
},
|
||||
'state_context': state_context,
|
||||
}
|
||||
|
||||
pending_artifact_refs: list[dict[str, typing.Any]] = []
|
||||
seen_sequences: set[int] = set()
|
||||
last_sequence = 0
|
||||
assistant_transcript_written = False
|
||||
terminal_status: str | None = None
|
||||
terminal_reason: str | None = None
|
||||
terminal_usage: dict[str, typing.Any] | None = None
|
||||
|
||||
try:
|
||||
await self.journal.create_run(
|
||||
event=event,
|
||||
binding=binding,
|
||||
descriptor=descriptor,
|
||||
context=context,
|
||||
authorization=run_authorization,
|
||||
)
|
||||
await self._session_registry.register(
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
@@ -146,14 +173,15 @@ class AgentRunOrchestrator:
|
||||
)
|
||||
|
||||
async for result_dict in self.invoker.invoke(descriptor, context):
|
||||
result_dict = dict(result_dict)
|
||||
sequence = result_dict.get('sequence')
|
||||
if sequence is not None:
|
||||
try:
|
||||
sequence_int = int(sequence)
|
||||
except (TypeError, ValueError):
|
||||
self.ap.logger.warning(
|
||||
f'Runner {descriptor.id} returned invalid result sequence: {sequence}'
|
||||
)
|
||||
self.ap.logger.warning(f'Runner {descriptor.id} returned invalid result sequence: {sequence}')
|
||||
sequence_int = last_sequence + 1
|
||||
result_dict['sequence'] = sequence_int
|
||||
else:
|
||||
if sequence_int in seen_sequences:
|
||||
self.ap.logger.warning(
|
||||
@@ -166,6 +194,8 @@ class AgentRunOrchestrator:
|
||||
f'Runner {descriptor.id} returned non-positive result sequence '
|
||||
f'{sequence_int} for run {run_id}'
|
||||
)
|
||||
sequence_int = last_sequence + 1
|
||||
result_dict['sequence'] = sequence_int
|
||||
elif last_sequence and sequence_int != last_sequence + 1:
|
||||
self.ap.logger.warning(
|
||||
f'Runner {descriptor.id} result sequence gap or out-of-order '
|
||||
@@ -173,6 +203,11 @@ class AgentRunOrchestrator:
|
||||
)
|
||||
seen_sequences.add(sequence_int)
|
||||
last_sequence = max(last_sequence, sequence_int)
|
||||
else:
|
||||
sequence_int = last_sequence + 1
|
||||
result_dict['sequence'] = sequence_int
|
||||
seen_sequences.add(sequence_int)
|
||||
last_sequence = sequence_int
|
||||
|
||||
result_type = result_dict.get('type')
|
||||
if result_type and not self.result_normalizer.validate_payload(
|
||||
@@ -190,9 +225,21 @@ class AgentRunOrchestrator:
|
||||
runner_id=descriptor.id,
|
||||
)
|
||||
pending_artifact_refs.append(artifact_ref)
|
||||
await self.journal.append_run_result(
|
||||
result_dict=result_dict,
|
||||
run_id=run_id,
|
||||
sequence=sequence_int,
|
||||
artifact_refs=[artifact_ref],
|
||||
)
|
||||
await self.result_normalizer.normalize(result_dict, descriptor)
|
||||
continue
|
||||
|
||||
await self.journal.append_run_result(
|
||||
result_dict=result_dict,
|
||||
run_id=run_id,
|
||||
sequence=sequence_int,
|
||||
)
|
||||
|
||||
if result_type == 'state.updated':
|
||||
await self.journal.handle_state_updated_event(
|
||||
result_dict,
|
||||
@@ -204,13 +251,28 @@ class AgentRunOrchestrator:
|
||||
await self.result_normalizer.normalize(result_dict, descriptor)
|
||||
continue
|
||||
|
||||
has_completed_message = (
|
||||
result_type == 'message.completed'
|
||||
or (
|
||||
result_type == 'run.completed'
|
||||
and isinstance(result_dict.get('data'), dict)
|
||||
and bool(result_dict['data'].get('message'))
|
||||
if result_type == 'run.completed':
|
||||
terminal_status = 'completed'
|
||||
terminal_reason = (
|
||||
result_dict.get('data', {}).get('finish_reason')
|
||||
if isinstance(result_dict.get('data'), dict)
|
||||
else None
|
||||
)
|
||||
usage = result_dict.get('usage')
|
||||
if isinstance(usage, dict):
|
||||
terminal_usage = usage
|
||||
elif result_type == 'run.failed':
|
||||
terminal_status = 'failed'
|
||||
data = result_dict.get('data') if isinstance(result_dict.get('data'), dict) else {}
|
||||
terminal_reason = data.get('error') or data.get('code')
|
||||
usage = result_dict.get('usage')
|
||||
if isinstance(usage, dict):
|
||||
terminal_usage = usage
|
||||
|
||||
has_completed_message = result_type == 'message.completed' or (
|
||||
result_type == 'run.completed'
|
||||
and isinstance(result_dict.get('data'), dict)
|
||||
and bool(result_dict['data'].get('message'))
|
||||
)
|
||||
if has_completed_message and event.conversation_id and not assistant_transcript_written:
|
||||
merged_refs = self.journal.merge_artifact_refs(
|
||||
@@ -231,6 +293,27 @@ class AgentRunOrchestrator:
|
||||
result = await self.result_normalizer.normalize(result_dict, descriptor)
|
||||
if result is not None:
|
||||
yield result
|
||||
|
||||
run_snapshot = await self.journal.get_run(run_id)
|
||||
if run_snapshot and run_snapshot.get('cancel_requested_at') is not None:
|
||||
terminal_status = 'cancelled'
|
||||
terminal_reason = run_snapshot.get('status_reason') or 'cancel_requested'
|
||||
break
|
||||
await self.journal.finalize_run(
|
||||
run_id=run_id,
|
||||
status=terminal_status or 'completed',
|
||||
status_reason=terminal_reason,
|
||||
usage=terminal_usage,
|
||||
)
|
||||
except Exception as exc:
|
||||
failed_usage = terminal_usage
|
||||
await self.journal.finalize_run(
|
||||
run_id=run_id,
|
||||
status='timeout' if self._is_deadline_exhausted(context) else 'failed',
|
||||
status_reason=str(exc),
|
||||
usage=failed_usage,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
session = await self._session_registry.unregister(run_id)
|
||||
pending_steering = session.get('steering_queue', []) if session else []
|
||||
@@ -325,9 +408,7 @@ class AgentRunOrchestrator:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
self.ap.logger.info(
|
||||
f'Claimed event {event.event_id} as steering input for run {target_run_id}'
|
||||
)
|
||||
self.ap.logger.info(f'Claimed event {event.event_id} as steering input for run {target_run_id}')
|
||||
return True
|
||||
|
||||
def _build_steering_item(
|
||||
|
||||
@@ -9,6 +9,7 @@ from .descriptor import AgentRunnerDescriptor
|
||||
from .errors import RunnerProtocolError
|
||||
from .host_models import AgentBinding, AgentEventEnvelope
|
||||
from .persistent_state_store import PersistentStateStore, get_persistent_state_store
|
||||
from .run_ledger_store import RunLedgerStore
|
||||
|
||||
|
||||
# Maximum inline artifact content size (1MB)
|
||||
@@ -21,10 +22,17 @@ class AgentRunJournal:
|
||||
ap: app.Application
|
||||
|
||||
_persistent_state_store: PersistentStateStore | None
|
||||
_run_ledger_store: RunLedgerStore | None
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self._persistent_state_store = None
|
||||
self._run_ledger_store = None
|
||||
|
||||
def _get_run_ledger_store(self) -> RunLedgerStore:
|
||||
if self._run_ledger_store is None:
|
||||
self._run_ledger_store = RunLedgerStore(self.ap.persistence_mgr.get_db_engine())
|
||||
return self._run_ledger_store
|
||||
|
||||
@staticmethod
|
||||
def _to_plain_dict(value: typing.Any) -> dict[str, typing.Any]:
|
||||
@@ -64,6 +72,81 @@ class AgentRunJournal:
|
||||
def _sanitize_attachments(cls, attachments: typing.Iterable[typing.Any]) -> list[dict[str, typing.Any]]:
|
||||
return [cls._sanitize_attachment_ref(attachment) for attachment in attachments]
|
||||
|
||||
async def create_run(
|
||||
self,
|
||||
*,
|
||||
event: AgentEventEnvelope,
|
||||
binding: AgentBinding,
|
||||
descriptor: AgentRunnerDescriptor,
|
||||
context: dict[str, typing.Any],
|
||||
authorization: dict[str, typing.Any],
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Create the Host-owned run ledger record."""
|
||||
runtime = context.get('runtime') if isinstance(context, dict) else {}
|
||||
return await self._get_run_ledger_store().create_run(
|
||||
run_id=context['run_id'],
|
||||
event_id=event.event_id,
|
||||
binding_id=binding.binding_id,
|
||||
runner_id=descriptor.id,
|
||||
conversation_id=event.conversation_id,
|
||||
thread_id=event.thread_id,
|
||||
workspace_id=event.workspace_id,
|
||||
bot_id=event.bot_id,
|
||||
deadline_at=runtime.get('deadline_at') if isinstance(runtime, dict) else None,
|
||||
authorization=authorization,
|
||||
metadata={
|
||||
'event_type': event.event_type,
|
||||
'source': event.source,
|
||||
},
|
||||
)
|
||||
|
||||
async def append_run_result(
|
||||
self,
|
||||
*,
|
||||
result_dict: dict[str, typing.Any],
|
||||
run_id: str,
|
||||
sequence: int,
|
||||
source: str = 'runner',
|
||||
artifact_refs: list[dict[str, typing.Any]] | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Persist one AgentRunResult in the run ledger."""
|
||||
usage = result_dict.get('usage')
|
||||
if hasattr(usage, 'model_dump'):
|
||||
usage = usage.model_dump(mode='json')
|
||||
return await self._get_run_ledger_store().append_event(
|
||||
run_id=run_id,
|
||||
sequence=sequence,
|
||||
event_type=str(result_dict.get('type') or 'unknown'),
|
||||
data=result_dict.get('data') if isinstance(result_dict.get('data'), dict) else {},
|
||||
usage=usage if isinstance(usage, dict) else None,
|
||||
source=source,
|
||||
artifact_refs=artifact_refs,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def finalize_run(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
status: str,
|
||||
status_reason: str | None = None,
|
||||
usage: dict[str, typing.Any] | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Finalize or update the Host-owned run ledger record."""
|
||||
return await self._get_run_ledger_store().finalize_run(
|
||||
run_id=run_id,
|
||||
status=status,
|
||||
status_reason=status_reason,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def get_run(self, run_id: str) -> dict[str, typing.Any] | None:
|
||||
"""Return the persisted run ledger record."""
|
||||
return await self._get_run_ledger_store().get_run(run_id)
|
||||
|
||||
async def handle_state_updated_event(
|
||||
self,
|
||||
result_dict: dict[str, typing.Any],
|
||||
@@ -99,9 +182,7 @@ class AgentRunJournal:
|
||||
)
|
||||
|
||||
if self._persistent_state_store is None:
|
||||
self._persistent_state_store = get_persistent_state_store(
|
||||
self.ap.persistence_mgr.get_db_engine()
|
||||
)
|
||||
self._persistent_state_store = get_persistent_state_store(self.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
success, error = await self._persistent_state_store.apply_update_from_event(
|
||||
event=event,
|
||||
@@ -114,13 +195,9 @@ class AgentRunJournal:
|
||||
)
|
||||
|
||||
if success:
|
||||
self.ap.logger.debug(
|
||||
f'Runner {descriptor.id} state.updated (event mode): scope={scope}, key={key}'
|
||||
)
|
||||
self.ap.logger.debug(f'Runner {descriptor.id} state.updated (event mode): scope={scope}, key={key}')
|
||||
elif error:
|
||||
self.ap.logger.warning(
|
||||
f'Runner {descriptor.id} state.updated rejected: {error}'
|
||||
)
|
||||
self.ap.logger.warning(f'Runner {descriptor.id} state.updated rejected: {error}')
|
||||
|
||||
async def write_event_log(
|
||||
self,
|
||||
@@ -166,9 +243,7 @@ class AgentRunJournal:
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
event_time=(
|
||||
datetime.datetime.fromtimestamp(event.event_time, datetime.timezone.utc)
|
||||
if event.event_time
|
||||
else None
|
||||
datetime.datetime.fromtimestamp(event.event_time, datetime.timezone.utc) if event.event_time else None
|
||||
),
|
||||
metadata=metadata,
|
||||
)
|
||||
@@ -239,9 +314,7 @@ class AgentRunJournal:
|
||||
content=content,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(
|
||||
f'Failed to register input artifact {artifact_id}: {e}'
|
||||
)
|
||||
self.ap.logger.warning(f'Failed to register input artifact {artifact_id}: {e}')
|
||||
|
||||
def decode_attachment_content(
|
||||
self,
|
||||
|
||||
@@ -0,0 +1,600 @@
|
||||
"""Run ledger store for Host-owned AgentRun and AgentRunEvent records."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import typing
|
||||
import uuid
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from ...entity.persistence.agent_run import AgentRun, AgentRunEvent, AgentRuntime
|
||||
|
||||
|
||||
UTC = datetime.timezone.utc
|
||||
TERMINAL_STATUSES = {'completed', 'failed', 'cancelled', 'timeout'}
|
||||
|
||||
|
||||
def _utc_now() -> datetime.datetime:
|
||||
return datetime.datetime.now(UTC)
|
||||
|
||||
|
||||
def _datetime_to_epoch(value: datetime.datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
else:
|
||||
value = value.astimezone(UTC)
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def _epoch_to_datetime(value: typing.Any) -> datetime.datetime | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return datetime.datetime.fromtimestamp(float(value), UTC)
|
||||
except (TypeError, ValueError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def _json_dumps(value: typing.Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return json.dumps(value)
|
||||
|
||||
|
||||
def _json_loads(value: str | None, default: typing.Any) -> typing.Any:
|
||||
if not value:
|
||||
return default
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
class RunLedgerStore:
|
||||
"""Store for Host-owned run lifecycle and result event facts."""
|
||||
|
||||
engine: AsyncEngine
|
||||
|
||||
def __init__(self, engine: AsyncEngine):
|
||||
self.engine = engine
|
||||
self._session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async def create_run(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
event_id: str | None,
|
||||
binding_id: str | None,
|
||||
runner_id: str,
|
||||
conversation_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
deadline_at: int | float | None = None,
|
||||
authorization: dict[str, typing.Any] | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
status: str = 'running',
|
||||
queue_name: str | None = None,
|
||||
priority: int = 0,
|
||||
requested_runtime_id: str | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Create a run if it does not already exist."""
|
||||
now = _utc_now()
|
||||
async with self._session_factory() as session:
|
||||
existing = await self._get_run_row(session, run_id)
|
||||
if existing is not None:
|
||||
return self._run_to_dict(existing)
|
||||
|
||||
run = AgentRun(
|
||||
run_id=run_id,
|
||||
event_id=event_id,
|
||||
agent_id=agent_id,
|
||||
binding_id=binding_id,
|
||||
runner_id=runner_id,
|
||||
conversation_id=conversation_id,
|
||||
thread_id=thread_id,
|
||||
workspace_id=workspace_id,
|
||||
bot_id=bot_id,
|
||||
status=status,
|
||||
queue_name=queue_name,
|
||||
priority=priority,
|
||||
requested_runtime_id=requested_runtime_id,
|
||||
created_at=now,
|
||||
started_at=now if status == 'running' else None,
|
||||
updated_at=now,
|
||||
deadline_at=_epoch_to_datetime(deadline_at),
|
||||
authorization_json=_json_dumps(authorization),
|
||||
metadata_json=_json_dumps(metadata),
|
||||
)
|
||||
session.add(run)
|
||||
await session.commit()
|
||||
return self._run_to_dict(run)
|
||||
|
||||
async def claim_next_run(
|
||||
self,
|
||||
*,
|
||||
runtime_id: str,
|
||||
queue_name: str | None = None,
|
||||
lease_seconds: int = 60,
|
||||
runner_ids: list[str] | None = None,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Claim the next queued or expired-leased run for a runtime."""
|
||||
now = _utc_now()
|
||||
lease_expires_at = now + datetime.timedelta(seconds=max(int(lease_seconds), 1))
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(AgentRun).where(
|
||||
sqlalchemy.or_(
|
||||
AgentRun.status == 'queued',
|
||||
sqlalchemy.and_(
|
||||
AgentRun.status == 'claimed',
|
||||
AgentRun.claim_lease_expires_at.is_not(None),
|
||||
AgentRun.claim_lease_expires_at <= now,
|
||||
),
|
||||
),
|
||||
sqlalchemy.or_(
|
||||
AgentRun.requested_runtime_id.is_(None),
|
||||
AgentRun.requested_runtime_id == runtime_id,
|
||||
),
|
||||
)
|
||||
if queue_name is not None:
|
||||
query = query.where(AgentRun.queue_name == queue_name)
|
||||
if runner_ids:
|
||||
query = query.where(AgentRun.runner_id.in_(runner_ids))
|
||||
|
||||
query = query.order_by(AgentRun.priority.desc(), AgentRun.id.asc()).limit(1).with_for_update(
|
||||
skip_locked=True
|
||||
)
|
||||
result = await session.execute(query)
|
||||
run = result.scalars().first()
|
||||
if run is None:
|
||||
return None
|
||||
|
||||
run.status = 'claimed'
|
||||
run.claimed_by_runtime_id = runtime_id
|
||||
run.claim_token = uuid.uuid4().hex
|
||||
run.claim_lease_expires_at = lease_expires_at
|
||||
run.dispatch_attempts = (run.dispatch_attempts or 0) + 1
|
||||
run.last_claimed_at = now
|
||||
run.updated_at = now
|
||||
await session.commit()
|
||||
return self._run_to_dict(run)
|
||||
|
||||
async def renew_claim(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
claim_token: str,
|
||||
lease_seconds: int = 60,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Extend a current claim lease if the token still matches."""
|
||||
now = _utc_now()
|
||||
async with self._session_factory() as session:
|
||||
run = await self._get_run_row(session, run_id)
|
||||
if run is None or run.status != 'claimed' or run.claim_token != claim_token:
|
||||
return None
|
||||
|
||||
run.claim_lease_expires_at = now + datetime.timedelta(seconds=max(int(lease_seconds), 1))
|
||||
run.updated_at = now
|
||||
await session.commit()
|
||||
return self._run_to_dict(run)
|
||||
|
||||
async def release_claim(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
claim_token: str,
|
||||
status: str = 'queued',
|
||||
status_reason: str | None = None,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Release a current claim lease if the token still matches."""
|
||||
now = _utc_now()
|
||||
async with self._session_factory() as session:
|
||||
run = await self._get_run_row(session, run_id)
|
||||
if run is None or run.status != 'claimed' or run.claim_token != claim_token:
|
||||
return None
|
||||
|
||||
run.status = status
|
||||
run.status_reason = status_reason
|
||||
run.claimed_by_runtime_id = None
|
||||
run.claim_token = None
|
||||
run.claim_lease_expires_at = None
|
||||
run.updated_at = now
|
||||
if status in TERMINAL_STATUSES:
|
||||
run.finished_at = run.finished_at or now
|
||||
await session.commit()
|
||||
return self._run_to_dict(run)
|
||||
|
||||
async def append_event(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
sequence: int,
|
||||
event_type: str,
|
||||
data: dict[str, typing.Any] | None = None,
|
||||
usage: dict[str, typing.Any] | None = None,
|
||||
source: str = 'runner',
|
||||
artifact_refs: list[dict[str, typing.Any]] | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Append one run result event.
|
||||
|
||||
If the same run_id + sequence already exists, the existing row is
|
||||
returned. This supports retrying append calls idempotently.
|
||||
"""
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(AgentRunEvent).where(
|
||||
AgentRunEvent.run_id == run_id,
|
||||
AgentRunEvent.sequence == sequence,
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
if existing is not None:
|
||||
return self._event_to_dict(existing)
|
||||
|
||||
row = AgentRunEvent(
|
||||
run_id=run_id,
|
||||
sequence=sequence,
|
||||
type=event_type,
|
||||
data_json=_json_dumps(data or {}),
|
||||
usage_json=_json_dumps(usage),
|
||||
created_at=_utc_now(),
|
||||
source=source,
|
||||
artifact_refs_json=_json_dumps(artifact_refs or []),
|
||||
metadata_json=_json_dumps(metadata),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
return self._event_to_dict(row)
|
||||
|
||||
async def finalize_run(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
status: str,
|
||||
status_reason: str | None = None,
|
||||
usage: dict[str, typing.Any] | None = None,
|
||||
cost: dict[str, typing.Any] | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Update a run to a terminal or current status."""
|
||||
now = _utc_now()
|
||||
async with self._session_factory() as session:
|
||||
run = await self._get_run_row(session, run_id)
|
||||
if run is None:
|
||||
return None
|
||||
|
||||
run.status = status
|
||||
run.status_reason = status_reason
|
||||
run.updated_at = now
|
||||
if status in TERMINAL_STATUSES:
|
||||
run.finished_at = run.finished_at or now
|
||||
if usage is not None:
|
||||
run.usage_json = _json_dumps(usage)
|
||||
if cost is not None:
|
||||
run.cost_json = _json_dumps(cost)
|
||||
if metadata is not None:
|
||||
existing_metadata = _json_loads(run.metadata_json, {})
|
||||
if isinstance(existing_metadata, dict):
|
||||
existing_metadata.update(metadata)
|
||||
run.metadata_json = _json_dumps(existing_metadata)
|
||||
else:
|
||||
run.metadata_json = _json_dumps(metadata)
|
||||
await session.commit()
|
||||
return self._run_to_dict(run)
|
||||
|
||||
async def request_cancel(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
status_reason: str | None = None,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Record a cancellation request."""
|
||||
now = _utc_now()
|
||||
async with self._session_factory() as session:
|
||||
run = await self._get_run_row(session, run_id)
|
||||
if run is None:
|
||||
return None
|
||||
run.cancel_requested_at = now
|
||||
run.updated_at = now
|
||||
run.status_reason = status_reason or run.status_reason
|
||||
await session.commit()
|
||||
return self._run_to_dict(run)
|
||||
|
||||
async def get_run(self, run_id: str) -> dict[str, typing.Any] | None:
|
||||
"""Get one run by run_id."""
|
||||
async with self._session_factory() as session:
|
||||
row = await self._get_run_row(session, run_id)
|
||||
return self._run_to_dict(row) if row is not None else None
|
||||
|
||||
async def register_runtime(
|
||||
self,
|
||||
*,
|
||||
runtime_id: str,
|
||||
status: str = 'online',
|
||||
display_name: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
version: str | None = None,
|
||||
capabilities: dict[str, typing.Any] | None = None,
|
||||
labels: dict[str, typing.Any] | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
heartbeat_deadline_seconds: int = 60,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Create or update a runtime registry row and record a heartbeat."""
|
||||
now = _utc_now()
|
||||
async with self._session_factory() as session:
|
||||
runtime = await self._get_runtime_row(session, runtime_id)
|
||||
if runtime is None:
|
||||
runtime = AgentRuntime(runtime_id=runtime_id, created_at=now)
|
||||
session.add(runtime)
|
||||
|
||||
runtime.status = status
|
||||
runtime.display_name = display_name
|
||||
runtime.endpoint = endpoint
|
||||
runtime.version = version
|
||||
runtime.capabilities_json = _json_dumps(capabilities or {})
|
||||
runtime.labels_json = _json_dumps(labels or {})
|
||||
runtime.metadata_json = _json_dumps(metadata or {})
|
||||
runtime.last_heartbeat_at = now
|
||||
runtime.heartbeat_deadline_at = now + datetime.timedelta(seconds=max(int(heartbeat_deadline_seconds), 1))
|
||||
runtime.updated_at = now
|
||||
await session.commit()
|
||||
return self._runtime_to_dict(runtime)
|
||||
|
||||
async def heartbeat_runtime(
|
||||
self,
|
||||
*,
|
||||
runtime_id: str,
|
||||
status: str = 'online',
|
||||
heartbeat_deadline_seconds: int = 60,
|
||||
capabilities: dict[str, typing.Any] | None = None,
|
||||
labels: dict[str, typing.Any] | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Refresh a runtime heartbeat."""
|
||||
now = _utc_now()
|
||||
async with self._session_factory() as session:
|
||||
runtime = await self._get_runtime_row(session, runtime_id)
|
||||
if runtime is None:
|
||||
return None
|
||||
|
||||
runtime.status = status
|
||||
runtime.last_heartbeat_at = now
|
||||
runtime.heartbeat_deadline_at = now + datetime.timedelta(seconds=max(int(heartbeat_deadline_seconds), 1))
|
||||
runtime.updated_at = now
|
||||
if capabilities is not None:
|
||||
runtime.capabilities_json = _json_dumps(capabilities)
|
||||
if labels is not None:
|
||||
runtime.labels_json = _json_dumps(labels)
|
||||
if metadata is not None:
|
||||
existing_metadata = _json_loads(runtime.metadata_json, {})
|
||||
if isinstance(existing_metadata, dict):
|
||||
existing_metadata.update(metadata)
|
||||
runtime.metadata_json = _json_dumps(existing_metadata)
|
||||
else:
|
||||
runtime.metadata_json = _json_dumps(metadata)
|
||||
await session.commit()
|
||||
return self._runtime_to_dict(runtime)
|
||||
|
||||
async def get_runtime(self, runtime_id: str) -> dict[str, typing.Any] | None:
|
||||
"""Get one runtime by runtime_id."""
|
||||
async with self._session_factory() as session:
|
||||
row = await self._get_runtime_row(session, runtime_id)
|
||||
return self._runtime_to_dict(row) if row is not None else None
|
||||
|
||||
async def list_runtimes(
|
||||
self,
|
||||
*,
|
||||
statuses: list[str] | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[dict[str, typing.Any]]:
|
||||
"""List runtime registry rows."""
|
||||
limit = min(max(int(limit), 1), 500)
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(AgentRuntime)
|
||||
if statuses:
|
||||
query = query.where(AgentRuntime.status.in_(statuses))
|
||||
query = query.order_by(AgentRuntime.id.asc()).limit(limit)
|
||||
result = await session.execute(query)
|
||||
return [self._runtime_to_dict(row) for row in result.scalars().all()]
|
||||
|
||||
async def mark_stale_runtimes(
|
||||
self,
|
||||
*,
|
||||
now: datetime.datetime | None = None,
|
||||
stale_status: str = 'stale',
|
||||
) -> list[dict[str, typing.Any]]:
|
||||
"""Mark runtimes stale when their heartbeat deadline has passed."""
|
||||
current_time = now or _utc_now()
|
||||
if current_time.tzinfo is None:
|
||||
current_time = current_time.replace(tzinfo=UTC)
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(AgentRuntime).where(
|
||||
AgentRuntime.heartbeat_deadline_at.is_not(None),
|
||||
AgentRuntime.heartbeat_deadline_at < current_time,
|
||||
AgentRuntime.status != stale_status,
|
||||
)
|
||||
)
|
||||
runtimes = result.scalars().all()
|
||||
for runtime in runtimes:
|
||||
runtime.status = stale_status
|
||||
runtime.updated_at = current_time
|
||||
await session.commit()
|
||||
return [self._runtime_to_dict(runtime) for runtime in runtimes]
|
||||
|
||||
async def list_runs(
|
||||
self,
|
||||
*,
|
||||
conversation_id: str | None = None,
|
||||
statuses: list[str] | None = None,
|
||||
before_id: int | None = None,
|
||||
limit: int = 50,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
strict_thread: bool = False,
|
||||
) -> tuple[list[dict[str, typing.Any]], int | None, bool]:
|
||||
"""Page runs by scope."""
|
||||
limit = min(max(int(limit), 1), 100)
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(AgentRun)
|
||||
if conversation_id is not None:
|
||||
query = query.where(AgentRun.conversation_id == conversation_id)
|
||||
if statuses:
|
||||
query = query.where(AgentRun.status.in_(statuses))
|
||||
if before_id is not None:
|
||||
query = query.where(AgentRun.id < before_id)
|
||||
query = self._apply_scope_filters(query, bot_id, workspace_id, thread_id, strict_thread)
|
||||
query = query.order_by(AgentRun.id.desc()).limit(limit + 1)
|
||||
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
items = [self._run_to_dict(row) for row in rows[:limit]]
|
||||
has_more = len(rows) > limit
|
||||
next_cursor = items[-1]['id'] if items and has_more else None
|
||||
return items, next_cursor, has_more
|
||||
|
||||
async def page_run_events(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
before_sequence: int | None = None,
|
||||
after_sequence: int | None = None,
|
||||
limit: int = 50,
|
||||
direction: str = 'forward',
|
||||
) -> tuple[list[dict[str, typing.Any]], int | None, int | None, bool]:
|
||||
"""Page result events for one run."""
|
||||
limit = min(max(int(limit), 1), 100)
|
||||
direction = direction if direction in {'forward', 'backward'} else 'forward'
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(AgentRunEvent).where(AgentRunEvent.run_id == run_id)
|
||||
if before_sequence is not None:
|
||||
query = query.where(AgentRunEvent.sequence < before_sequence)
|
||||
if after_sequence is not None:
|
||||
query = query.where(AgentRunEvent.sequence > after_sequence)
|
||||
|
||||
if direction == 'backward':
|
||||
query = query.order_by(AgentRunEvent.sequence.desc())
|
||||
else:
|
||||
query = query.order_by(AgentRunEvent.sequence.asc())
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
items = [self._event_to_dict(row) for row in rows[:limit]]
|
||||
has_more = len(rows) > limit
|
||||
|
||||
if direction == 'backward':
|
||||
next_cursor = items[-1]['sequence'] if items and has_more else None
|
||||
prev_cursor = items[0]['sequence'] if items else None
|
||||
else:
|
||||
next_cursor = items[-1]['sequence'] if items and has_more else None
|
||||
prev_cursor = items[0]['sequence'] if items else None
|
||||
return items, next_cursor, prev_cursor, has_more
|
||||
|
||||
async def _get_run_row(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
run_id: str,
|
||||
) -> AgentRun | None:
|
||||
result = await session.execute(sqlalchemy.select(AgentRun).where(AgentRun.run_id == run_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async def _get_runtime_row(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
runtime_id: str,
|
||||
) -> AgentRuntime | None:
|
||||
result = await session.execute(sqlalchemy.select(AgentRuntime).where(AgentRuntime.runtime_id == runtime_id))
|
||||
return result.scalars().first()
|
||||
|
||||
def _apply_scope_filters(
|
||||
self,
|
||||
query: typing.Any,
|
||||
bot_id: str | None,
|
||||
workspace_id: str | None,
|
||||
thread_id: str | None,
|
||||
strict_thread: bool,
|
||||
) -> typing.Any:
|
||||
if bot_id is not None:
|
||||
query = query.where(AgentRun.bot_id == bot_id)
|
||||
if workspace_id is not None:
|
||||
query = query.where(AgentRun.workspace_id == workspace_id)
|
||||
if strict_thread:
|
||||
if thread_id is None:
|
||||
query = query.where(AgentRun.thread_id.is_(None))
|
||||
else:
|
||||
query = query.where(AgentRun.thread_id == thread_id)
|
||||
return query
|
||||
|
||||
def _run_to_dict(self, row: AgentRun) -> dict[str, typing.Any]:
|
||||
return {
|
||||
'id': row.id,
|
||||
'run_id': row.run_id,
|
||||
'event_id': row.event_id,
|
||||
'agent_id': row.agent_id,
|
||||
'binding_id': row.binding_id,
|
||||
'runner_id': row.runner_id,
|
||||
'conversation_id': row.conversation_id,
|
||||
'thread_id': row.thread_id,
|
||||
'workspace_id': row.workspace_id,
|
||||
'bot_id': row.bot_id,
|
||||
'status': row.status,
|
||||
'status_reason': row.status_reason,
|
||||
'queue_name': row.queue_name,
|
||||
'priority': row.priority,
|
||||
'requested_runtime_id': row.requested_runtime_id,
|
||||
'claimed_by_runtime_id': row.claimed_by_runtime_id,
|
||||
'claim_token': row.claim_token,
|
||||
'claim_lease_expires_at': _datetime_to_epoch(row.claim_lease_expires_at),
|
||||
'dispatch_attempts': row.dispatch_attempts,
|
||||
'last_claimed_at': _datetime_to_epoch(row.last_claimed_at),
|
||||
'created_at': _datetime_to_epoch(row.created_at),
|
||||
'started_at': _datetime_to_epoch(row.started_at),
|
||||
'finished_at': _datetime_to_epoch(row.finished_at),
|
||||
'updated_at': _datetime_to_epoch(row.updated_at),
|
||||
'deadline_at': _datetime_to_epoch(row.deadline_at),
|
||||
'cancel_requested_at': _datetime_to_epoch(row.cancel_requested_at),
|
||||
'usage': _json_loads(row.usage_json, None),
|
||||
'cost': _json_loads(row.cost_json, None),
|
||||
'metadata': _json_loads(row.metadata_json, {}),
|
||||
}
|
||||
|
||||
def _runtime_to_dict(self, row: AgentRuntime) -> dict[str, typing.Any]:
|
||||
return {
|
||||
'id': row.id,
|
||||
'runtime_id': row.runtime_id,
|
||||
'status': row.status,
|
||||
'display_name': row.display_name,
|
||||
'endpoint': row.endpoint,
|
||||
'version': row.version,
|
||||
'capabilities': _json_loads(row.capabilities_json, {}),
|
||||
'labels': _json_loads(row.labels_json, {}),
|
||||
'metadata': _json_loads(row.metadata_json, {}),
|
||||
'last_heartbeat_at': _datetime_to_epoch(row.last_heartbeat_at),
|
||||
'heartbeat_deadline_at': _datetime_to_epoch(row.heartbeat_deadline_at),
|
||||
'created_at': _datetime_to_epoch(row.created_at),
|
||||
'updated_at': _datetime_to_epoch(row.updated_at),
|
||||
}
|
||||
|
||||
def _event_to_dict(self, row: AgentRunEvent) -> dict[str, typing.Any]:
|
||||
return {
|
||||
'id': row.id,
|
||||
'run_id': row.run_id,
|
||||
'sequence': row.sequence,
|
||||
'type': row.type,
|
||||
'data': _json_loads(row.data_json, {}),
|
||||
'usage': _json_loads(row.usage_json, None),
|
||||
'created_at': _datetime_to_epoch(row.created_at),
|
||||
'source': row.source,
|
||||
'artifact_refs': _json_loads(row.artifact_refs_json, []),
|
||||
'metadata': _json_loads(row.metadata_json, {}),
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
"""Agent run ledger persistence entities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class AgentRun(Base):
|
||||
"""AgentRun stores Host-owned execution lifecycle facts."""
|
||||
|
||||
__tablename__ = 'agent_run'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
"""Auto-increment ID for pagination."""
|
||||
|
||||
run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, unique=True, index=True)
|
||||
"""Unique AgentRunner run identifier."""
|
||||
|
||||
event_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Input event that triggered this run."""
|
||||
|
||||
agent_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Future Host-owned agent identifier."""
|
||||
|
||||
binding_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Binding that selected this runner."""
|
||||
|
||||
runner_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
"""Runner descriptor ID."""
|
||||
|
||||
conversation_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Conversation this run belongs to."""
|
||||
|
||||
thread_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Thread this run belongs to."""
|
||||
|
||||
workspace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Workspace this run belongs to."""
|
||||
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Bot UUID this run belongs to."""
|
||||
|
||||
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True)
|
||||
"""Run lifecycle status."""
|
||||
|
||||
status_reason = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Human-readable terminal or current status reason."""
|
||||
|
||||
queue_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Host queue name this run is waiting in."""
|
||||
|
||||
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
"""Higher values are claimed before lower values within a queue."""
|
||||
|
||||
requested_runtime_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Specific runtime requested by the producer, if any."""
|
||||
|
||||
claimed_by_runtime_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Runtime that currently owns the claim lease."""
|
||||
|
||||
claim_token = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Opaque token required to renew or release the current claim."""
|
||||
|
||||
claim_lease_expires_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True, index=True)
|
||||
"""When the current claim lease expires."""
|
||||
|
||||
dispatch_attempts = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
"""Number of times this run has been claimed for dispatch."""
|
||||
|
||||
last_claimed_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
"""When this run was last claimed."""
|
||||
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow)
|
||||
"""When the run record was created."""
|
||||
|
||||
started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
"""When execution started."""
|
||||
|
||||
finished_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
"""When execution reached a terminal status."""
|
||||
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
)
|
||||
"""When the run record was last updated."""
|
||||
|
||||
deadline_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
"""Execution deadline if one was assigned."""
|
||||
|
||||
cancel_requested_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
"""When cancellation was requested."""
|
||||
|
||||
usage_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Final or latest aggregate token usage JSON."""
|
||||
|
||||
cost_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Host-calculated cost JSON, if available."""
|
||||
|
||||
authorization_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Run-scoped authorization snapshot JSON."""
|
||||
|
||||
metadata_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Additional metadata JSON."""
|
||||
|
||||
__table_args__ = (
|
||||
sqlalchemy.Index(
|
||||
'ix_agent_run_scope_status', 'bot_id', 'workspace_id', 'conversation_id', 'thread_id', 'status'
|
||||
),
|
||||
sqlalchemy.Index('ix_agent_run_runner_status', 'runner_id', 'status'),
|
||||
sqlalchemy.Index('ix_agent_run_queue_claim', 'queue_name', 'status', 'priority', 'id'),
|
||||
)
|
||||
|
||||
|
||||
class AgentRuntime(Base):
|
||||
"""AgentRuntime stores Host-owned runtime heartbeat registry facts."""
|
||||
|
||||
__tablename__ = 'agent_runtime'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
"""Auto-increment ID."""
|
||||
|
||||
runtime_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, unique=True, index=True)
|
||||
"""Unique runtime or daemon identifier."""
|
||||
|
||||
status = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, index=True)
|
||||
"""Runtime lifecycle status."""
|
||||
|
||||
display_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Human-readable runtime display name."""
|
||||
|
||||
endpoint = sqlalchemy.Column(sqlalchemy.String(1024), nullable=True)
|
||||
"""Runtime endpoint, if it exposes one."""
|
||||
|
||||
version = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Runtime version string."""
|
||||
|
||||
capabilities_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Runtime capabilities JSON."""
|
||||
|
||||
labels_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Runtime labels JSON."""
|
||||
|
||||
metadata_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Additional metadata JSON."""
|
||||
|
||||
last_heartbeat_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True, index=True)
|
||||
"""When the runtime last sent a heartbeat."""
|
||||
|
||||
heartbeat_deadline_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True, index=True)
|
||||
"""When the runtime should be considered stale."""
|
||||
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow)
|
||||
"""When the runtime record was created."""
|
||||
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
|
||||
)
|
||||
"""When the runtime record was last updated."""
|
||||
|
||||
|
||||
class AgentRunEvent(Base):
|
||||
"""AgentRunEvent stores one result event emitted by a run."""
|
||||
|
||||
__tablename__ = 'agent_run_event'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
"""Auto-increment ID."""
|
||||
|
||||
run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, index=True)
|
||||
"""Run that produced this event."""
|
||||
|
||||
sequence = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
"""Monotonic sequence inside the run."""
|
||||
|
||||
type = sqlalchemy.Column(sqlalchemy.String(100), nullable=False, index=True)
|
||||
"""Result event type."""
|
||||
|
||||
data_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Result event payload JSON."""
|
||||
|
||||
usage_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Token usage JSON for this event, if provided."""
|
||||
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow)
|
||||
"""When this event was persisted."""
|
||||
|
||||
source = sqlalchemy.Column(sqlalchemy.String(50), nullable=True)
|
||||
"""Source that appended the event."""
|
||||
|
||||
artifact_refs_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Artifact references associated with this event."""
|
||||
|
||||
metadata_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Additional metadata JSON."""
|
||||
|
||||
__table_args__ = (
|
||||
sqlalchemy.UniqueConstraint('run_id', 'sequence', name='uq_agent_run_event_run_sequence'),
|
||||
sqlalchemy.Index('ix_agent_run_event_run_sequence', 'run_id', 'sequence'),
|
||||
)
|
||||
@@ -16,6 +16,7 @@ from langbot.pkg.entity.persistence.base import Base
|
||||
# Import all ORM models so they are registered with Base.metadata
|
||||
# This is required for autogenerate to detect model changes
|
||||
from langbot.pkg.entity.persistence import (
|
||||
agent_run, # noqa: F401
|
||||
agent_runner_state, # noqa: F401
|
||||
apikey, # noqa: F401
|
||||
artifact, # noqa: F401
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
"""add agent run ledger
|
||||
|
||||
Revision ID: 8d3a1f2c4b6e
|
||||
Revises: 7b2c1d9e4f30
|
||||
Create Date: 2026-06-15
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = '8d3a1f2c4b6e'
|
||||
down_revision = '7b2c1d9e4f30'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _table_exists(table_name: str) -> bool:
|
||||
return table_name in sa.inspect(op.get_bind()).get_table_names()
|
||||
|
||||
|
||||
def _index_exists(table_name: str, index_name: str) -> bool:
|
||||
return index_name in {index['name'] for index in sa.inspect(op.get_bind()).get_indexes(table_name)}
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
|
||||
if not _table_exists(table_name):
|
||||
return False
|
||||
return column_name in {column['name'] for column in sa.inspect(op.get_bind()).get_columns(table_name)}
|
||||
|
||||
|
||||
def _add_column_if_missing(table_name: str, column: sa.Column) -> None:
|
||||
if not _table_exists(table_name) or _column_exists(table_name, column.name):
|
||||
return
|
||||
with op.batch_alter_table(table_name, schema=None) as batch_op:
|
||||
batch_op.add_column(column)
|
||||
|
||||
|
||||
def _create_index_if_missing(table_name: str, index_name: str, columns: list[str], *, unique: bool = False) -> None:
|
||||
if not _table_exists(table_name) or _index_exists(table_name, index_name):
|
||||
return
|
||||
existing_columns = {column['name'] for column in sa.inspect(op.get_bind()).get_columns(table_name)}
|
||||
if not set(columns).issubset(existing_columns):
|
||||
return
|
||||
with op.batch_alter_table(table_name, schema=None) as batch_op:
|
||||
batch_op.create_index(index_name, columns, unique=unique)
|
||||
|
||||
|
||||
def _drop_index_if_exists(table_name: str, index_name: str) -> None:
|
||||
if not _table_exists(table_name) or not _index_exists(table_name, index_name):
|
||||
return
|
||||
with op.batch_alter_table(table_name, schema=None) as batch_op:
|
||||
batch_op.drop_index(index_name)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if not _table_exists('agent_run'):
|
||||
op.create_table(
|
||||
'agent_run',
|
||||
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column('run_id', sa.String(255), nullable=False, unique=True),
|
||||
sa.Column('event_id', sa.String(255), nullable=True),
|
||||
sa.Column('agent_id', sa.String(255), nullable=True),
|
||||
sa.Column('binding_id', sa.String(255), nullable=True),
|
||||
sa.Column('runner_id', sa.String(255), nullable=False),
|
||||
sa.Column('conversation_id', sa.String(255), nullable=True),
|
||||
sa.Column('thread_id', sa.String(255), nullable=True),
|
||||
sa.Column('workspace_id', sa.String(255), nullable=True),
|
||||
sa.Column('bot_id', sa.String(255), nullable=True),
|
||||
sa.Column('status', sa.String(50), nullable=False),
|
||||
sa.Column('status_reason', sa.Text(), nullable=True),
|
||||
sa.Column('queue_name', sa.String(255), nullable=True),
|
||||
sa.Column('priority', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('requested_runtime_id', sa.String(255), nullable=True),
|
||||
sa.Column('claimed_by_runtime_id', sa.String(255), nullable=True),
|
||||
sa.Column('claim_token', sa.String(255), nullable=True),
|
||||
sa.Column('claim_lease_expires_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('dispatch_attempts', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('last_claimed_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('finished_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
|
||||
sa.Column('deadline_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('cancel_requested_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('usage_json', sa.Text(), nullable=True),
|
||||
sa.Column('cost_json', sa.Text(), nullable=True),
|
||||
sa.Column('authorization_json', sa.Text(), nullable=True),
|
||||
sa.Column('metadata_json', sa.Text(), nullable=True),
|
||||
)
|
||||
else:
|
||||
_add_column_if_missing('agent_run', sa.Column('queue_name', sa.String(255), nullable=True))
|
||||
_add_column_if_missing(
|
||||
'agent_run', sa.Column('priority', sa.Integer(), nullable=False, server_default='0')
|
||||
)
|
||||
_add_column_if_missing('agent_run', sa.Column('requested_runtime_id', sa.String(255), nullable=True))
|
||||
_add_column_if_missing('agent_run', sa.Column('claimed_by_runtime_id', sa.String(255), nullable=True))
|
||||
_add_column_if_missing('agent_run', sa.Column('claim_token', sa.String(255), nullable=True))
|
||||
_add_column_if_missing('agent_run', sa.Column('claim_lease_expires_at', sa.DateTime(), nullable=True))
|
||||
_add_column_if_missing(
|
||||
'agent_run', sa.Column('dispatch_attempts', sa.Integer(), nullable=False, server_default='0')
|
||||
)
|
||||
_add_column_if_missing('agent_run', sa.Column('last_claimed_at', sa.DateTime(), nullable=True))
|
||||
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_run_id', ['run_id'], unique=True)
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_event_id', ['event_id'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_binding_id', ['binding_id'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_runner_id', ['runner_id'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_conversation_id', ['conversation_id'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_bot_id', ['bot_id'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_status', ['status'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_queue_name', ['queue_name'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_requested_runtime_id', ['requested_runtime_id'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_claimed_by_runtime_id', ['claimed_by_runtime_id'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_claim_token', ['claim_token'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_claim_lease_expires_at', ['claim_lease_expires_at'])
|
||||
_create_index_if_missing(
|
||||
'agent_run',
|
||||
'ix_agent_run_scope_status',
|
||||
['bot_id', 'workspace_id', 'conversation_id', 'thread_id', 'status'],
|
||||
)
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_runner_status', ['runner_id', 'status'])
|
||||
_create_index_if_missing('agent_run', 'ix_agent_run_queue_claim', ['queue_name', 'status', 'priority', 'id'])
|
||||
|
||||
if not _table_exists('agent_run_event'):
|
||||
op.create_table(
|
||||
'agent_run_event',
|
||||
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column('run_id', sa.String(255), nullable=False),
|
||||
sa.Column('sequence', sa.Integer(), nullable=False),
|
||||
sa.Column('type', sa.String(100), nullable=False),
|
||||
sa.Column('data_json', sa.Text(), nullable=True),
|
||||
sa.Column('usage_json', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
|
||||
sa.Column('source', sa.String(50), nullable=True),
|
||||
sa.Column('artifact_refs_json', sa.Text(), nullable=True),
|
||||
sa.Column('metadata_json', sa.Text(), nullable=True),
|
||||
sa.UniqueConstraint('run_id', 'sequence', name='uq_agent_run_event_run_sequence'),
|
||||
)
|
||||
|
||||
_create_index_if_missing('agent_run_event', 'ix_agent_run_event_run_id', ['run_id'])
|
||||
_create_index_if_missing('agent_run_event', 'ix_agent_run_event_type', ['type'])
|
||||
_create_index_if_missing(
|
||||
'agent_run_event',
|
||||
'ix_agent_run_event_run_sequence',
|
||||
['run_id', 'sequence'],
|
||||
)
|
||||
|
||||
if not _table_exists('agent_runtime'):
|
||||
op.create_table(
|
||||
'agent_runtime',
|
||||
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column('runtime_id', sa.String(255), nullable=False, unique=True),
|
||||
sa.Column('status', sa.String(50), nullable=False),
|
||||
sa.Column('display_name', sa.String(255), nullable=True),
|
||||
sa.Column('endpoint', sa.String(1024), nullable=True),
|
||||
sa.Column('version', sa.String(255), nullable=True),
|
||||
sa.Column('capabilities_json', sa.Text(), nullable=True),
|
||||
sa.Column('labels_json', sa.Text(), nullable=True),
|
||||
sa.Column('metadata_json', sa.Text(), nullable=True),
|
||||
sa.Column('last_heartbeat_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('heartbeat_deadline_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
|
||||
)
|
||||
|
||||
_create_index_if_missing('agent_runtime', 'ix_agent_runtime_runtime_id', ['runtime_id'], unique=True)
|
||||
_create_index_if_missing('agent_runtime', 'ix_agent_runtime_status', ['status'])
|
||||
_create_index_if_missing('agent_runtime', 'ix_agent_runtime_last_heartbeat_at', ['last_heartbeat_at'])
|
||||
_create_index_if_missing('agent_runtime', 'ix_agent_runtime_heartbeat_deadline_at', ['heartbeat_deadline_at'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
_drop_index_if_exists('agent_runtime', 'ix_agent_runtime_heartbeat_deadline_at')
|
||||
_drop_index_if_exists('agent_runtime', 'ix_agent_runtime_last_heartbeat_at')
|
||||
_drop_index_if_exists('agent_runtime', 'ix_agent_runtime_status')
|
||||
_drop_index_if_exists('agent_runtime', 'ix_agent_runtime_runtime_id')
|
||||
if _table_exists('agent_runtime'):
|
||||
op.drop_table('agent_runtime')
|
||||
|
||||
_drop_index_if_exists('agent_run_event', 'ix_agent_run_event_run_sequence')
|
||||
_drop_index_if_exists('agent_run_event', 'ix_agent_run_event_type')
|
||||
_drop_index_if_exists('agent_run_event', 'ix_agent_run_event_run_id')
|
||||
if _table_exists('agent_run_event'):
|
||||
op.drop_table('agent_run_event')
|
||||
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_queue_claim')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_claim_lease_expires_at')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_claim_token')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_claimed_by_runtime_id')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_requested_runtime_id')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_queue_name')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_runner_status')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_scope_status')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_status')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_bot_id')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_conversation_id')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_runner_id')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_binding_id')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_event_id')
|
||||
_drop_index_if_exists('agent_run', 'ix_agent_run_run_id')
|
||||
if _table_exists('agent_run'):
|
||||
op.drop_table('agent_run')
|
||||
+804
-100
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,490 @@
|
||||
"""Tests for AgentRunner run ledger pull API authorization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.run_ledger_store import RunLedgerStore
|
||||
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry
|
||||
from langbot.pkg.entity.persistence import agent_run as agent_run_model
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.run_ledger import (
|
||||
AgentRun,
|
||||
AgentRunEvent,
|
||||
RunEventPage,
|
||||
RunPage,
|
||||
)
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
from .conftest import make_resources
|
||||
|
||||
|
||||
class FakeConnection:
|
||||
pass
|
||||
|
||||
|
||||
class FakeApplication:
|
||||
def __init__(self, db_engine):
|
||||
self.logger = MagicMock()
|
||||
self.persistence_mgr = MagicMock()
|
||||
self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_registry(monkeypatch):
|
||||
registry = AgentRunSessionRegistry()
|
||||
monkeypatch.setattr(
|
||||
'langbot.pkg.agent.runner.session_registry._global_registry',
|
||||
registry,
|
||||
)
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine('sqlite+aiosqlite:///:memory:')
|
||||
assert agent_run_model.AgentRun.__tablename__ == 'agent_run'
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def _handler(db_engine):
|
||||
async def fake_disconnect():
|
||||
return True
|
||||
|
||||
fake_app = FakeApplication(db_engine)
|
||||
return RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||
|
||||
|
||||
async def _register_session(
|
||||
session_registry,
|
||||
*,
|
||||
run_id='run_1',
|
||||
conversation_id='conv_1',
|
||||
available_apis=None,
|
||||
):
|
||||
await session_registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=None,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
conversation_id=conversation_id,
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
thread_id=None,
|
||||
available_apis=available_apis or {},
|
||||
)
|
||||
|
||||
|
||||
async def _create_run(
|
||||
db_engine,
|
||||
*,
|
||||
run_id='run_1',
|
||||
conversation_id='conv_1',
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
thread_id=None,
|
||||
plugin_identity='test/runner',
|
||||
available_apis=None,
|
||||
):
|
||||
store = RunLedgerStore(db_engine)
|
||||
await store.create_run(
|
||||
run_id=run_id,
|
||||
event_id='evt_1',
|
||||
binding_id='binding_1',
|
||||
runner_id='plugin:test/runner/default',
|
||||
conversation_id=conversation_id,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
thread_id=thread_id,
|
||||
authorization={
|
||||
'runner_id': 'plugin:test/runner/default',
|
||||
'binding_id': 'binding_1',
|
||||
'plugin_identity': plugin_identity,
|
||||
'resources': make_resources(),
|
||||
'available_apis': available_apis or {},
|
||||
'conversation_id': conversation_id,
|
||||
'bot_id': bot_id,
|
||||
'workspace_id': workspace_id,
|
||||
'thread_id': thread_id,
|
||||
'state_policy': {'enable_state': True, 'state_scopes': ['conversation', 'actor']},
|
||||
'state_context': {},
|
||||
},
|
||||
)
|
||||
await store.append_event(
|
||||
run_id=run_id,
|
||||
sequence=1,
|
||||
event_type='message.completed',
|
||||
data={'message': {'role': 'assistant', 'content': 'ok'}},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_get_returns_current_run(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'run_get': True})
|
||||
await _create_run(db_engine)
|
||||
handler = _handler(db_engine)
|
||||
run_get = handler.actions[PluginToRuntimeAction.RUN_GET.value]
|
||||
|
||||
result = await run_get(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code == 0
|
||||
run = AgentRun.model_validate(result.data)
|
||||
assert run.run_id == 'run_1'
|
||||
assert run.status == 'running'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_list_rejects_cross_conversation(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'run_list': True})
|
||||
handler = _handler(db_engine)
|
||||
run_list = handler.actions[PluginToRuntimeAction.RUN_LIST.value]
|
||||
|
||||
result = await run_list(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'conversation_id': 'conv_other',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not accessible' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_list_returns_scoped_runs(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'run_list': True})
|
||||
await _create_run(db_engine)
|
||||
await _create_run(db_engine, run_id='run_other', conversation_id='conv_other')
|
||||
handler = _handler(db_engine)
|
||||
run_list = handler.actions[PluginToRuntimeAction.RUN_LIST.value]
|
||||
|
||||
result = await run_list(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code == 0
|
||||
page = RunPage.model_validate(result.data)
|
||||
assert [run.run_id for run in page.items] == ['run_1']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_events_page_returns_events(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'run_events_page': True})
|
||||
await _create_run(db_engine)
|
||||
handler = _handler(db_engine)
|
||||
run_events_page = handler.actions[PluginToRuntimeAction.RUN_EVENTS_PAGE.value]
|
||||
|
||||
result = await run_events_page(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code == 0
|
||||
page = RunEventPage.model_validate(result.data)
|
||||
assert [item.sequence for item in page.items] == [1]
|
||||
assert page.items[0].type == 'message.completed'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_get_uses_persistent_authorization_after_session_expired(db_engine):
|
||||
await _create_run(db_engine, available_apis={'run_get': True})
|
||||
handler = _handler(db_engine)
|
||||
run_get = handler.actions[PluginToRuntimeAction.RUN_GET.value]
|
||||
|
||||
result = await run_get(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code == 0
|
||||
run = AgentRun.model_validate(result.data)
|
||||
assert run.run_id == 'run_1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistent_run_get_rejects_cross_scope(db_engine):
|
||||
await _create_run(db_engine, available_apis={'run_get': True})
|
||||
await _create_run(
|
||||
db_engine,
|
||||
run_id='run_other',
|
||||
conversation_id='conv_other',
|
||||
available_apis={'run_get': True},
|
||||
)
|
||||
handler = _handler(db_engine)
|
||||
run_get = handler.actions[PluginToRuntimeAction.RUN_GET.value]
|
||||
|
||||
result = await run_get(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'target_run_id': 'run_other',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not accessible' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistent_run_get_requires_capability(db_engine):
|
||||
await _create_run(db_engine, available_apis={'run_get': False})
|
||||
handler = _handler(db_engine)
|
||||
run_get = handler.actions[PluginToRuntimeAction.RUN_GET.value]
|
||||
|
||||
result = await run_get(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not authorized' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistent_authorization_does_not_reopen_artifact_api(db_engine):
|
||||
await _create_run(db_engine, available_apis={'artifact_metadata': True})
|
||||
handler = _handler(db_engine)
|
||||
artifact_metadata = handler.actions[PluginToRuntimeAction.ARTIFACT_METADATA.value]
|
||||
|
||||
result = await artifact_metadata(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'artifact_id': 'artifact_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not found or expired' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cancel_basic_path(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'run_cancel': True})
|
||||
await _create_run(db_engine)
|
||||
handler = _handler(db_engine)
|
||||
run_cancel = handler.actions[PluginToRuntimeAction.RUN_CANCEL.value]
|
||||
|
||||
result = await run_cancel(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'reason': 'user requested',
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code == 0
|
||||
run = AgentRun.model_validate(result.data)
|
||||
assert run.run_id == 'run_1'
|
||||
assert run.cancel_requested_at is not None
|
||||
assert run.status_reason == 'user requested'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_append_result_basic_path(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'run_append_result': True})
|
||||
await _create_run(db_engine)
|
||||
handler = _handler(db_engine)
|
||||
run_append_result = handler.actions[PluginToRuntimeAction.RUN_APPEND_RESULT.value]
|
||||
|
||||
result = await run_append_result(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'sequence': 2,
|
||||
'result': {
|
||||
'type': 'message.delta',
|
||||
'data': {'delta': 'hello'},
|
||||
'usage': {'output_tokens': 1},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code == 0
|
||||
event = AgentRunEvent.model_validate(result.data)
|
||||
assert event.run_id == 'run_1'
|
||||
assert event.sequence == 2
|
||||
assert event.type == 'message.delta'
|
||||
assert event.data == {'delta': 'hello'}
|
||||
assert event.usage == {'output_tokens': 1}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_finalize_basic_path(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'run_finalize': True})
|
||||
await _create_run(db_engine)
|
||||
handler = _handler(db_engine)
|
||||
run_finalize = handler.actions[PluginToRuntimeAction.RUN_FINALIZE.value]
|
||||
|
||||
result = await run_finalize(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'status': 'completed',
|
||||
'status_reason': 'done',
|
||||
'usage': {'total_tokens': 3},
|
||||
}
|
||||
)
|
||||
|
||||
assert result.code == 0
|
||||
run = AgentRun.model_validate(result.data)
|
||||
assert run.status == 'completed'
|
||||
assert run.status_reason == 'done'
|
||||
assert run.finished_at is not None
|
||||
assert run.usage == {'total_tokens': 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_register_heartbeat_and_list_actions(session_registry, db_engine):
|
||||
await _register_session(
|
||||
session_registry,
|
||||
available_apis={
|
||||
'runtime_register': True,
|
||||
'runtime_heartbeat': True,
|
||||
'runtime_list': True,
|
||||
},
|
||||
)
|
||||
handler = _handler(db_engine)
|
||||
runtime_register = handler.actions[PluginToRuntimeAction.RUNTIME_REGISTER.value]
|
||||
runtime_heartbeat = handler.actions[PluginToRuntimeAction.RUNTIME_HEARTBEAT.value]
|
||||
runtime_list = handler.actions[PluginToRuntimeAction.RUNTIME_LIST.value]
|
||||
|
||||
registered = await runtime_register(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'runtime_id': 'runtime_1',
|
||||
'display_name': 'Runtime 1',
|
||||
'capabilities': {'runner': True},
|
||||
'labels': {'region': 'test'},
|
||||
'metadata': {'slots': 2},
|
||||
}
|
||||
)
|
||||
|
||||
assert registered.code == 0
|
||||
assert registered.data['runtime_id'] == 'runtime_1'
|
||||
assert registered.data['capabilities'] == {'runner': True}
|
||||
|
||||
heartbeat = await runtime_heartbeat(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'runtime_id': 'runtime_1',
|
||||
'capabilities': {'runner': True, 'stream': True},
|
||||
'labels': {'region': 'test'},
|
||||
'metadata': {'active_runs': 1},
|
||||
}
|
||||
)
|
||||
|
||||
assert heartbeat.code == 0
|
||||
assert heartbeat.data['capabilities'] == {'runner': True, 'stream': True}
|
||||
assert heartbeat.data['metadata'] == {'slots': 2, 'active_runs': 1}
|
||||
|
||||
page = await runtime_list(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'statuses': ['online'],
|
||||
'labels': {'region': 'test'},
|
||||
}
|
||||
)
|
||||
|
||||
assert page.code == 0
|
||||
assert [item['runtime_id'] for item in page.data['items']] == ['runtime_1']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_claim_renew_and_release_actions(session_registry, db_engine):
|
||||
await _register_session(
|
||||
session_registry,
|
||||
available_apis={
|
||||
'run_claim': True,
|
||||
'run_renew_claim': True,
|
||||
'run_release_claim': True,
|
||||
},
|
||||
)
|
||||
await RunLedgerStore(db_engine).create_run(
|
||||
run_id='queued_run',
|
||||
event_id='evt_queued',
|
||||
binding_id='binding_1',
|
||||
runner_id='plugin:test/runner/default',
|
||||
conversation_id='conv_1',
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
priority=5,
|
||||
)
|
||||
handler = _handler(db_engine)
|
||||
run_claim = handler.actions[PluginToRuntimeAction.RUN_CLAIM.value]
|
||||
run_renew_claim = handler.actions[PluginToRuntimeAction.RUN_RENEW_CLAIM.value]
|
||||
run_release_claim = handler.actions[PluginToRuntimeAction.RUN_RELEASE_CLAIM.value]
|
||||
|
||||
claimed = await run_claim(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'runtime_id': 'runtime_1',
|
||||
'queue_name': 'default',
|
||||
'lease_seconds': 30,
|
||||
}
|
||||
)
|
||||
|
||||
assert claimed.code == 0
|
||||
assert claimed.data['run_id'] == 'queued_run'
|
||||
assert claimed.data['status'] == 'claimed'
|
||||
assert claimed.data['claimed_by_runtime_id'] == 'runtime_1'
|
||||
claim_token = claimed.data['claim_token']
|
||||
|
||||
renewed = await run_renew_claim(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'target_run_id': 'queued_run',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'runtime_id': 'runtime_1',
|
||||
'claim_token': claim_token,
|
||||
'lease_seconds': 60,
|
||||
}
|
||||
)
|
||||
|
||||
assert renewed.code == 0
|
||||
assert renewed.data['claim_token'] == claim_token
|
||||
|
||||
released = await run_release_claim(
|
||||
{
|
||||
'run_id': 'run_1',
|
||||
'target_run_id': 'queued_run',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'runtime_id': 'runtime_1',
|
||||
'claim_token': claim_token,
|
||||
'reason': 'done with lease',
|
||||
}
|
||||
)
|
||||
|
||||
assert released.code == 0
|
||||
assert released.data['status'] == 'queued'
|
||||
assert released.data['claimed_by_runtime_id'] is None
|
||||
assert released.data['claim_token'] is None
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Tests for RunLedgerStore host primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from langbot.pkg.agent.runner.run_ledger_store import RunLedgerStore
|
||||
from langbot.pkg.entity.persistence.agent_run import AgentRun
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
|
||||
|
||||
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(tmp_path):
|
||||
db_path = tmp_path / 'run_ledger_store.db'
|
||||
engine = create_async_engine(f'sqlite+aiosqlite:///{db_path}', echo=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(db_engine):
|
||||
return RunLedgerStore(db_engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_queued_run_claim_renew_release(store):
|
||||
run = await store.create_run(
|
||||
run_id='run-queued',
|
||||
event_id='evt-1',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
priority=10,
|
||||
requested_runtime_id='runtime-a',
|
||||
)
|
||||
|
||||
assert run['status'] == 'queued'
|
||||
assert run['started_at'] is None
|
||||
assert run['queue_name'] == 'default'
|
||||
assert run['priority'] == 10
|
||||
assert run['requested_runtime_id'] == 'runtime-a'
|
||||
|
||||
assert await store.claim_next_run(runtime_id='runtime-b', queue_name='default') is None
|
||||
|
||||
claimed = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=30)
|
||||
|
||||
assert claimed is not None
|
||||
assert claimed['run_id'] == 'run-queued'
|
||||
assert claimed['status'] == 'claimed'
|
||||
assert claimed['claimed_by_runtime_id'] == 'runtime-a'
|
||||
assert claimed['claim_token']
|
||||
assert claimed['dispatch_attempts'] == 1
|
||||
assert claimed['claim_lease_expires_at'] is not None
|
||||
assert claimed['last_claimed_at'] is not None
|
||||
|
||||
token = claimed['claim_token']
|
||||
assert await store.renew_claim(run_id='run-queued', claim_token='wrong-token') is None
|
||||
|
||||
renewed = await store.renew_claim(run_id='run-queued', claim_token=token, lease_seconds=90)
|
||||
|
||||
assert renewed is not None
|
||||
assert renewed['claim_token'] == token
|
||||
assert renewed['claim_lease_expires_at'] >= claimed['claim_lease_expires_at']
|
||||
|
||||
released = await store.release_claim(
|
||||
run_id='run-queued',
|
||||
claim_token=token,
|
||||
status='queued',
|
||||
status_reason='runtime released capacity',
|
||||
)
|
||||
|
||||
assert released is not None
|
||||
assert released['status'] == 'queued'
|
||||
assert released['status_reason'] == 'runtime released capacity'
|
||||
assert released['claimed_by_runtime_id'] is None
|
||||
assert released['claim_token'] is None
|
||||
assert released['claim_lease_expires_at'] is None
|
||||
assert released['dispatch_attempts'] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_claim_can_be_reclaimed(store, db_engine):
|
||||
await store.create_run(
|
||||
run_id='run-expired',
|
||||
event_id='evt-2',
|
||||
binding_id='binding-1',
|
||||
runner_id='runner-a',
|
||||
status='queued',
|
||||
queue_name='default',
|
||||
)
|
||||
first_claim = await store.claim_next_run(runtime_id='runtime-a', queue_name='default', lease_seconds=60)
|
||||
assert first_claim is not None
|
||||
|
||||
session_factory = sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with session_factory() as session:
|
||||
await session.execute(
|
||||
sqlalchemy.update(AgentRun)
|
||||
.where(AgentRun.run_id == 'run-expired')
|
||||
.values(claim_lease_expires_at=datetime.datetime.now(UTC) - datetime.timedelta(seconds=1))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
reclaimed = await store.claim_next_run(runtime_id='runtime-b', queue_name='default', lease_seconds=60)
|
||||
|
||||
assert reclaimed is not None
|
||||
assert reclaimed['run_id'] == 'run-expired'
|
||||
assert reclaimed['claimed_by_runtime_id'] == 'runtime-b'
|
||||
assert reclaimed['claim_token'] != first_claim['claim_token']
|
||||
assert reclaimed['dispatch_attempts'] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_register_heartbeat_list_and_mark_stale(store):
|
||||
registered = await store.register_runtime(
|
||||
runtime_id='runtime-a',
|
||||
display_name='Runtime A',
|
||||
endpoint='http://runtime-a',
|
||||
version='1.0.0',
|
||||
capabilities={'stream': True},
|
||||
labels={'region': 'test'},
|
||||
metadata={'slot_count': 2},
|
||||
heartbeat_deadline_seconds=30,
|
||||
)
|
||||
|
||||
assert registered['runtime_id'] == 'runtime-a'
|
||||
assert registered['status'] == 'online'
|
||||
assert registered['display_name'] == 'Runtime A'
|
||||
assert registered['capabilities'] == {'stream': True}
|
||||
assert registered['labels'] == {'region': 'test'}
|
||||
assert registered['metadata'] == {'slot_count': 2}
|
||||
assert registered['last_heartbeat_at'] is not None
|
||||
assert registered['heartbeat_deadline_at'] is not None
|
||||
|
||||
heartbeat = await store.heartbeat_runtime(
|
||||
runtime_id='runtime-a',
|
||||
metadata={'active_runs': 1},
|
||||
heartbeat_deadline_seconds=30,
|
||||
)
|
||||
|
||||
assert heartbeat is not None
|
||||
assert heartbeat['metadata'] == {'slot_count': 2, 'active_runs': 1}
|
||||
|
||||
runtimes = await store.list_runtimes(statuses=['online'])
|
||||
assert [runtime['runtime_id'] for runtime in runtimes] == ['runtime-a']
|
||||
|
||||
stale = await store.mark_stale_runtimes(
|
||||
now=datetime.datetime.now(UTC) + datetime.timedelta(seconds=31),
|
||||
)
|
||||
|
||||
assert [runtime['runtime_id'] for runtime in stale] == ['runtime-a']
|
||||
assert stale[0]['status'] == 'stale'
|
||||
assert (await store.get_runtime('runtime-a'))['status'] == 'stale'
|
||||
Reference in New Issue
Block a user