diff --git a/src/langbot/pkg/agent/runner/artifact_store.py b/src/langbot/pkg/agent/runner/artifact_store.py new file mode 100644 index 00000000..299f2ff0 --- /dev/null +++ b/src/langbot/pkg/agent/runner/artifact_store.py @@ -0,0 +1,300 @@ +"""Artifact store for managing Host-owned artifacts.""" +from __future__ import annotations + +import json +import datetime +import typing +import uuid +import base64 + +import sqlalchemy +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.orm import sessionmaker + +from ...entity.persistence.artifact import AgentArtifact +from ...entity.persistence.bstorage import BinaryStorage + + +class ArtifactStore: + """Store for AgentArtifact records. + + Handles artifact metadata registration and content retrieval. + Actual blob storage is delegated to BinaryStorage or external storage. + + All methods are async and use the provided database engine. + """ + + engine: AsyncEngine + + # Hard limits + MAX_INLINE_READ_BYTES = 1024 * 1024 # 1MB max for inline base64 + MAX_RANGE_READ_BYTES = 10 * 1024 * 1024 # 10MB max for range reads + + def __init__(self, engine: AsyncEngine): + self.engine = engine + self._session_factory = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + + async def register_artifact( + self, + artifact_id: str | None, + artifact_type: str, + source: str, + storage_key: str | None = None, + storage_type: str = 'binary_storage', + mime_type: str | None = None, + name: str | None = None, + size_bytes: int | None = None, + sha256: str | None = None, + conversation_id: str | None = None, + run_id: str | None = None, + runner_id: str | None = None, + bot_id: str | None = None, + workspace_id: str | None = None, + expires_at: datetime.datetime | None = None, + metadata: dict[str, typing.Any] | None = None, + content: bytes | None = None, + ) -> str: + """Register a new artifact. + + If content is provided and storage_key is None, stores content + in BinaryStorage automatically. + + Args: + artifact_id: Unique artifact ID (generated if None) + artifact_type: Type of artifact (image, file, voice, tool_result, etc.) + source: Source of artifact (platform, runner, tool, system) + storage_key: Key in BinaryStorage or external reference + storage_type: Storage type (binary_storage, file, url) + mime_type: MIME type + name: Original file name + size_bytes: Size in bytes + sha256: SHA256 hash + conversation_id: Conversation ID + run_id: Run ID that created this + runner_id: Runner ID that created this + bot_id: Bot UUID + workspace_id: Workspace ID + expires_at: Expiration time + metadata: Additional metadata + content: Optional content to store in BinaryStorage + + Returns: + The artifact_id + """ + if artifact_id is None: + artifact_id = str(uuid.uuid4()) + + # If content provided, store in BinaryStorage + if content is not None and storage_key is None: + storage_key = f"artifact:{artifact_id}" + storage_type = 'binary_storage' + if size_bytes is None: + size_bytes = len(content) + + async with self._session_factory() as session: + # Store content in BinaryStorage if provided + if content is not None: + binary_storage = BinaryStorage( + unique_key=f'artifact:{artifact_id}', + key=storage_key, + owner_type='artifact', + owner='host', + value=content, + ) + session.add(binary_storage) + + # Store artifact metadata + artifact = AgentArtifact( + artifact_id=artifact_id, + artifact_type=artifact_type, + mime_type=mime_type, + name=name, + size_bytes=size_bytes, + sha256=sha256, + source=source, + storage_key=storage_key, + storage_type=storage_type, + conversation_id=conversation_id, + run_id=run_id, + runner_id=runner_id, + bot_id=bot_id, + workspace_id=workspace_id, + created_at=datetime.datetime.utcnow(), + expires_at=expires_at, + metadata_json=json.dumps(metadata) if metadata else None, + ) + session.add(artifact) + await session.commit() + + return artifact_id + + async def get_metadata( + self, + artifact_id: str, + ) -> dict[str, typing.Any] | None: + """Get artifact metadata (public fields only, no internal storage info). + + Args: + artifact_id: Artifact ID + + Returns: + Artifact metadata dict compatible with SDK ArtifactMetadata, or None if not found + """ + async with self._session_factory() as session: + result = await session.execute( + sqlalchemy.select(AgentArtifact).where( + AgentArtifact.artifact_id == artifact_id + ) + ) + row = result.scalars().first() + if row is None: + return None + return self._row_to_public_dict(row) + + async def _get_internal_record( + self, + artifact_id: str, + ) -> AgentArtifact | None: + """Get full artifact record including internal fields. + + Used internally by read_artifact to access storage_key/storage_type. + + Args: + artifact_id: Artifact ID + + Returns: + AgentArtifact ORM instance, or None if not found + """ + async with self._session_factory() as session: + result = await session.execute( + sqlalchemy.select(AgentArtifact).where( + AgentArtifact.artifact_id == artifact_id + ) + ) + return result.scalars().first() + + async def read_artifact( + self, + artifact_id: str, + offset: int = 0, + limit: int | None = None, + ) -> dict[str, typing.Any] | None: + """Read artifact content. + + For small artifacts, returns content_base64 directly. + For large artifacts, returns file_key for chunked transfer. + + Args: + artifact_id: Artifact ID + offset: Byte offset to start reading from (must be >= 0) + limit: Maximum bytes to read (must be > 0 if provided) + + Returns: + ArtifactReadResult dict, or None if not found + + Raises: + ValueError: If offset < 0 or limit <= 0 + """ + # Validate offset and limit + if offset < 0: + raise ValueError("offset must be >= 0") + + if limit is not None and limit <= 0: + raise ValueError("limit must be > 0") + + # Get internal record (includes storage_key/storage_type) + record = await self._get_internal_record(artifact_id) + if record is None: + return None + + storage_type = record.storage_type or 'binary_storage' + storage_key = record.storage_key + size_bytes = record.size_bytes or 0 + + # Cap limit at hard limit + if limit is None: + limit = self.MAX_INLINE_READ_BYTES + limit = min(limit, self.MAX_RANGE_READ_BYTES) + + # For binary_storage, read content + if storage_type == 'binary_storage' and storage_key: + content = await self._read_binary_storage(storage_key) + if content is None: + return None + + # Apply offset and limit + if offset > 0: + content = content[offset:] + if limit and len(content) > limit: + content = content[:limit] + has_more = True + else: + has_more = False + + return { + 'artifact_id': artifact_id, + 'mime_type': record.mime_type, + 'size_bytes': size_bytes, + 'offset': offset, + 'length': len(content), + 'content_base64': base64.b64encode(content).decode('utf-8'), + 'file_key': None, + 'has_more': has_more, + } + + # For other storage types, return storage reference + # (caller can use file_key for chunked transfer) + return { + 'artifact_id': artifact_id, + 'mime_type': record.mime_type, + 'size_bytes': size_bytes, + 'offset': offset, + 'length': None, + 'content_base64': None, + 'file_key': storage_key, + 'has_more': False, + } + + async def _read_binary_storage(self, key: str) -> bytes | None: + """Read content from BinaryStorage. + + Uses unique_key for isolation to prevent cross-artifact access. + + Args: + key: The unique_key used when storing the artifact + + Returns: + Content bytes, or None if not found + """ + async with self._session_factory() as session: + result = await session.execute( + sqlalchemy.select(BinaryStorage).where(BinaryStorage.unique_key == key) + ) + row = result.scalars().first() + if row is None: + return None + return row.value + + def _row_to_public_dict(self, row: AgentArtifact) -> dict[str, typing.Any]: + """Convert an AgentArtifact row to public dict. + + Returns only fields that match SDK ArtifactMetadata entity. + Host-only fields (bot_id, workspace_id, storage_key, storage_type) are excluded. + """ + return { + 'artifact_id': row.artifact_id, + 'artifact_type': row.artifact_type, + 'mime_type': row.mime_type, + 'name': row.name, + 'size_bytes': row.size_bytes, + 'sha256': row.sha256, + 'source': row.source, + 'conversation_id': row.conversation_id, + 'run_id': row.run_id, + 'runner_id': row.runner_id, + 'created_at': int(row.created_at.timestamp()) if row.created_at else None, + 'expires_at': int(row.expires_at.timestamp()) if row.expires_at else None, + 'metadata': json.loads(row.metadata_json) if row.metadata_json else {}, + } diff --git a/src/langbot/pkg/agent/runner/context_builder.py b/src/langbot/pkg/agent/runner/context_builder.py index 95fb2606..14fb609c 100644 --- a/src/langbot/pkg/agent/runner/context_builder.py +++ b/src/langbot/pkg/agent/runner/context_builder.py @@ -891,11 +891,14 @@ class AgentRunContextBuilder: permissions = descriptor.permissions or {} history_permissions = permissions.get('history', []) event_permissions = permissions.get('events', []) + artifact_permissions = permissions.get('artifacts', []) history_page_enabled = 'page' in history_permissions and conversation_id is not None history_search_enabled = 'search' in history_permissions and conversation_id is not None event_get_enabled = 'get' in event_permissions event_page_enabled = 'page' in event_permissions and conversation_id is not None + artifact_metadata_enabled = 'metadata' in artifact_permissions + artifact_read_enabled = 'read' in artifact_permissions # Get latest cursor and has_history_before if conversation exists latest_cursor = None @@ -931,8 +934,8 @@ class AgentRunContextBuilder: 'history_search': history_search_enabled, 'event_get': event_get_enabled, 'event_page': event_page_enabled, - 'artifact_metadata': False, # TODO: Implement artifact store - 'artifact_read': False, + 'artifact_metadata': artifact_metadata_enabled, + 'artifact_read': artifact_read_enabled, 'state': True, 'storage': True, }, diff --git a/src/langbot/pkg/agent/runner/event_log_store.py b/src/langbot/pkg/agent/runner/event_log_store.py index 6eb08361..0b693b19 100644 --- a/src/langbot/pkg/agent/runner/event_log_store.py +++ b/src/langbot/pkg/agent/runner/event_log_store.py @@ -7,7 +7,8 @@ import typing import uuid import sqlalchemy -from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.orm import sessionmaker from ...entity.persistence.event_log import EventLog from ...entity.persistence.transcript import Transcript @@ -27,6 +28,9 @@ class EventLogStore: def __init__(self, engine: AsyncEngine): self.engine = engine + self._session_factory = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) async def append_event( self, @@ -83,32 +87,31 @@ class EventLogStore: if input_summary and len(input_summary) > self.MAX_INPUT_SUMMARY_LENGTH: input_summary = input_summary[:self.MAX_INPUT_SUMMARY_LENGTH - 3] + "..." - async with self.engine.connect() as conn: - await conn.execute( - sqlalchemy.insert(EventLog).values( - event_id=event_id, - event_type=event_type, - event_time=event_time, - source=source, - bot_id=bot_id, - workspace_id=workspace_id, - conversation_id=conversation_id, - thread_id=thread_id, - actor_type=actor_type, - actor_id=actor_id, - actor_name=actor_name, - subject_type=subject_type, - subject_id=subject_id, - input_summary=input_summary, - input_json=json.dumps(input_json) if input_json else None, - raw_ref=raw_ref, - run_id=run_id, - runner_id=runner_id, - metadata_json=json.dumps(metadata) if metadata else None, - created_at=datetime.datetime.utcnow(), - ) + async with self._session_factory() as session: + event = EventLog( + event_id=event_id, + event_type=event_type, + event_time=event_time, + source=source, + bot_id=bot_id, + workspace_id=workspace_id, + conversation_id=conversation_id, + thread_id=thread_id, + actor_type=actor_type, + actor_id=actor_id, + actor_name=actor_name, + subject_type=subject_type, + subject_id=subject_id, + input_summary=input_summary, + input_json=json.dumps(input_json) if input_json else None, + raw_ref=raw_ref, + run_id=run_id, + runner_id=runner_id, + metadata_json=json.dumps(metadata) if metadata else None, + created_at=datetime.datetime.utcnow(), ) - await conn.commit() + session.add(event) + await session.commit() return event_id @@ -124,14 +127,14 @@ class EventLogStore: Returns: Event record as dict, or None if not found """ - async with self.engine.connect() as conn: - result = await conn.execute( + async with self._session_factory() as session: + result = await session.execute( sqlalchemy.select(EventLog).where(EventLog.event_id == event_id) ) - row = result.fetchone() + row = result.scalars().first() if row is None: return None - return self._row_to_dict(row[0]) + return self._row_to_dict(row) async def page_events( self, @@ -153,7 +156,7 @@ class EventLogStore: """ limit = min(limit, 100) # Hard cap - async with self.engine.connect() as conn: + async with self._session_factory() as session: query = sqlalchemy.select(EventLog) if conversation_id is not None: @@ -167,10 +170,10 @@ class EventLogStore: query = query.order_by(EventLog.id.desc()).limit(limit + 1) - result = await conn.execute(query) - rows = result.fetchall() + result = await session.execute(query) + rows = result.scalars().all() - items = [self._row_to_dict(row[0]) for row in rows[:limit]] + items = [self._row_to_dict(row) for row in rows[:limit]] has_more = len(rows) > limit next_seq = items[-1]['id'] if items and has_more else None @@ -188,17 +191,17 @@ class EventLogStore: Returns: Cursor string (seq number), or None if no events """ - async with self.engine.connect() as conn: - result = await conn.execute( + async with self._session_factory() as session: + result = await session.execute( sqlalchemy.select(EventLog.id) .where(EventLog.conversation_id == conversation_id) .order_by(EventLog.id.desc()) .limit(1) ) - row = result.fetchone() + row = result.scalars().first() if row is None: return None - return str(row[0]) + return str(row) async def has_events_before( self, @@ -214,8 +217,8 @@ class EventLogStore: Returns: True if there are events before """ - async with self.engine.connect() as conn: - result = await conn.execute( + async with self._session_factory() as session: + result = await session.execute( sqlalchemy.select(sqlalchemy.func.count()) .select_from(EventLog) .where( diff --git a/src/langbot/pkg/agent/runner/orchestrator.py b/src/langbot/pkg/agent/runner/orchestrator.py index cf0e1d83..77f26968 100644 --- a/src/langbot/pkg/agent/runner/orchestrator.py +++ b/src/langbot/pkg/agent/runner/orchestrator.py @@ -125,6 +125,7 @@ class AgentRunOrchestrator: query_id=None, # No query_id in event-first mode plugin_identity=descriptor.get_plugin_id(), resources=resources, + permissions=descriptor.permissions or {}, conversation_id=event.conversation_id, ) @@ -222,6 +223,7 @@ class AgentRunOrchestrator: query_id=query.query_id, plugin_identity=descriptor.get_plugin_id(), resources=resources, + permissions=descriptor.permissions or {}, conversation_id=conversation_id, ) diff --git a/src/langbot/pkg/agent/runner/session_registry.py b/src/langbot/pkg/agent/runner/session_registry.py index 1d24f593..6a0dca3e 100644 --- a/src/langbot/pkg/agent/runner/session_registry.py +++ b/src/langbot/pkg/agent/runner/session_registry.py @@ -27,6 +27,7 @@ class AgentRunSession(typing.TypedDict): plugin_identity: Plugin identifier (author/name) of the runner conversation_id: Conversation ID for history/event access resources: Authorized resources for this run (from AgentResources) + permissions: Runner permissions from descriptor (artifacts, history, events, etc.) status: Session status tracking _authorized_ids: Pre-computed authorized resource IDs for O(1) lookup """ @@ -36,6 +37,7 @@ class AgentRunSession(typing.TypedDict): plugin_identity: str # author/name conversation_id: str | None resources: AgentResources + permissions: dict[str, list[str]] status: AgentRunSessionStatus _authorized_ids: dict[str, set[str]] # Pre-computed sets for O(1) lookup @@ -67,6 +69,7 @@ class AgentRunSessionRegistry: plugin_identity: str, resources: AgentResources, conversation_id: str | None = None, + permissions: dict[str, list[str]] | None = None, ) -> None: """Register a new agent run session. @@ -77,9 +80,13 @@ class AgentRunSessionRegistry: plugin_identity: Plugin identifier (author/name) resources: Authorized resources for this run conversation_id: Conversation ID for history/event access + permissions: Runner permissions from descriptor (artifacts, history, events, etc.) """ now = int(time.time()) + # Normalize permissions to empty dict if None + permissions = permissions or {} + # Pre-compute authorized resource IDs for O(1) lookup authorized_ids: dict[str, set[str]] = { 'model': {m.get('model_id') for m in resources.get('models', [])}, @@ -95,6 +102,7 @@ class AgentRunSessionRegistry: 'plugin_identity': plugin_identity, 'conversation_id': conversation_id, 'resources': resources, + 'permissions': permissions, 'status': { 'started_at': now, 'last_activity_at': now, diff --git a/src/langbot/pkg/agent/runner/transcript_store.py b/src/langbot/pkg/agent/runner/transcript_store.py index ea63c427..05064525 100644 --- a/src/langbot/pkg/agent/runner/transcript_store.py +++ b/src/langbot/pkg/agent/runner/transcript_store.py @@ -7,7 +7,8 @@ import typing import uuid import sqlalchemy -from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.orm import sessionmaker from ...entity.persistence.transcript import Transcript @@ -27,6 +28,9 @@ class TranscriptStore: def __init__(self, engine: AsyncEngine): self.engine = engine + self._session_factory = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) async def append_transcript( self, @@ -72,26 +76,25 @@ class TranscriptStore: # Get next sequence number for this conversation seq = await self._get_next_seq(conversation_id) - async with self.engine.connect() as conn: - await conn.execute( - sqlalchemy.insert(Transcript).values( - transcript_id=transcript_id, - event_id=event_id, - conversation_id=conversation_id, - thread_id=thread_id, - role=role, - item_type=item_type, - content=content, - content_json=json.dumps(content_json) if content_json else None, - artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None, - seq=seq, - run_id=run_id, - runner_id=runner_id, - created_at=datetime.datetime.utcnow(), - metadata_json=json.dumps(metadata) if metadata else None, - ) + async with self._session_factory() as session: + item = Transcript( + transcript_id=transcript_id, + event_id=event_id, + conversation_id=conversation_id, + thread_id=thread_id, + role=role, + item_type=item_type, + content=content, + content_json=json.dumps(content_json) if content_json else None, + artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None, + seq=seq, + run_id=run_id, + runner_id=runner_id, + created_at=datetime.datetime.utcnow(), + metadata_json=json.dumps(metadata) if metadata else None, ) - await conn.commit() + session.add(item) + await session.commit() return transcript_id @@ -119,7 +122,7 @@ class TranscriptStore: """ limit = min(limit, self.HARD_LIMIT) - async with self.engine.connect() as conn: + async with self._session_factory() as session: query = sqlalchemy.select(Transcript).where( Transcript.conversation_id == conversation_id ) @@ -136,10 +139,10 @@ class TranscriptStore: query = query.limit(limit + 1) - result = await conn.execute(query) - rows = result.fetchall() + result = await session.execute(query) + rows = result.scalars().all() - items = [self._row_to_dict(row[0], include_artifacts) for row in rows[:limit]] + items = [self._row_to_dict(row, include_artifacts) for row in rows[:limit]] has_more = len(rows) > limit # Calculate cursors @@ -179,7 +182,7 @@ class TranscriptStore: Returns: List of matching items """ - async with self.engine.connect() as conn: + async with self._session_factory() as session: query = sqlalchemy.select(Transcript).where( Transcript.conversation_id == conversation_id, Transcript.content.ilike(f"%{query_text}%"), @@ -194,10 +197,10 @@ class TranscriptStore: query = query.order_by(Transcript.seq.desc()).limit(top_k) - result = await conn.execute(query) - rows = result.fetchall() + result = await session.execute(query) + rows = result.scalars().all() - return [self._row_to_dict(row[0], include_artifacts=True) for row in rows] + return [self._row_to_dict(row, include_artifacts=True) for row in rows] async def get_latest_cursor( self, @@ -211,17 +214,17 @@ class TranscriptStore: Returns: Cursor string (seq number), or None if no items """ - async with self.engine.connect() as conn: - result = await conn.execute( + async with self._session_factory() as session: + result = await session.execute( sqlalchemy.select(Transcript.seq) .where(Transcript.conversation_id == conversation_id) .order_by(Transcript.seq.desc()) .limit(1) ) - row = result.fetchone() + row = result.scalars().first() if row is None: return None - return str(row[0]) + return str(row) async def has_history_before( self, @@ -237,8 +240,8 @@ class TranscriptStore: Returns: True if there are items before """ - async with self.engine.connect() as conn: - result = await conn.execute( + async with self._session_factory() as session: + result = await session.execute( sqlalchemy.select(sqlalchemy.func.count()) .select_from(Transcript) .where( @@ -251,8 +254,8 @@ class TranscriptStore: async def _get_next_seq(self, conversation_id: str) -> int: """Get the next sequence number for a conversation.""" - async with self.engine.connect() as conn: - result = await conn.execute( + async with self._session_factory() as session: + result = await session.execute( sqlalchemy.select(sqlalchemy.func.max(Transcript.seq)) .where(Transcript.conversation_id == conversation_id) ) diff --git a/src/langbot/pkg/entity/persistence/artifact.py b/src/langbot/pkg/entity/persistence/artifact.py new file mode 100644 index 00000000..2d4683e8 --- /dev/null +++ b/src/langbot/pkg/entity/persistence/artifact.py @@ -0,0 +1,77 @@ +"""Artifact persistence entity for Host-owned artifact store.""" +from __future__ import annotations + +import sqlalchemy +import datetime + +from .base import Base + + +class AgentArtifact(Base): + """AgentArtifact stores metadata for large files, images, tool results, etc. + + This table only stores metadata. The actual blob content is stored in + BinaryStorage or external storage, referenced by storage_key. + + Artifacts are accessed via artifact_metadata and artifact_read APIs + with run_id authorization. + """ + + __tablename__ = 'agent_artifact' + + id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True) + """Auto-increment ID for sequencing.""" + + artifact_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, unique=True, index=True) + """Unique artifact identifier.""" + + artifact_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) + """Artifact type: 'image', 'file', 'voice', 'tool_result', 'platform_attachment', etc.""" + + mime_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """MIME type of the content.""" + + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """Original file name (if applicable).""" + + size_bytes = sqlalchemy.Column(sqlalchemy.BigInteger, nullable=True) + """Size in bytes.""" + + sha256 = sqlalchemy.Column(sqlalchemy.String(64), nullable=True) + """SHA256 hash of content (for integrity verification).""" + + source = sqlalchemy.Column(sqlalchemy.String(50), nullable=False) + """Source of artifact: 'platform', 'runner', 'tool', 'system'.""" + + # Storage reference (points to BinaryStorage or external storage) + storage_key = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """Key in BinaryStorage or external storage reference.""" + + storage_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, default='binary_storage') + """Storage type: 'binary_storage', 'file', 'url', etc.""" + + # Context + conversation_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True) + """Conversation this artifact belongs to.""" + + run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True) + """Run ID that created this artifact.""" + + runner_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """Runner ID that created this artifact.""" + + bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """Bot UUID that handled this artifact.""" + + workspace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + """Workspace ID for multi-tenant deployments.""" + + # Lifecycle + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow) + """When this artifact was created.""" + + expires_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True) + """When this artifact expires (optional).""" + + metadata_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True) + """Additional metadata as JSON string.""" diff --git a/src/langbot/pkg/persistence/alembic/env.py b/src/langbot/pkg/persistence/alembic/env.py index 43fae5ab..6cb6d5b0 100644 --- a/src/langbot/pkg/persistence/alembic/env.py +++ b/src/langbot/pkg/persistence/alembic/env.py @@ -17,6 +17,7 @@ from langbot.pkg.entity.persistence.base import Base # This is required for autogenerate to detect model changes from langbot.pkg.entity.persistence import ( apikey, + artifact, bot, bstorage, event_log, diff --git a/src/langbot/pkg/persistence/alembic/versions/a1b2c3d4e5f6_add_agent_artifact_table.py b/src/langbot/pkg/persistence/alembic/versions/a1b2c3d4e5f6_add_agent_artifact_table.py new file mode 100644 index 00000000..244d3e45 --- /dev/null +++ b/src/langbot/pkg/persistence/alembic/versions/a1b2c3d4e5f6_add_agent_artifact_table.py @@ -0,0 +1,55 @@ +"""add_agent_artifact_table + +Revision ID: a1b2c3d4e5f6 +Revises: 58846a8d7a81 +Create Date: 2026-05-23 20:00:00.000000 +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers +revision = 'a1b2c3d4e5f6' +down_revision = '58846a8d7a81' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Create agent_artifact table + op.create_table( + 'agent_artifact', + sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True), + sa.Column('artifact_id', sa.String(255), nullable=False, unique=True), + sa.Column('artifact_type', sa.String(50), nullable=False), + sa.Column('mime_type', sa.String(255), nullable=True), + sa.Column('name', sa.String(255), nullable=True), + sa.Column('size_bytes', sa.BigInteger(), nullable=True), + sa.Column('sha256', sa.String(64), nullable=True), + sa.Column('source', sa.String(50), nullable=False), + sa.Column('storage_key', sa.String(255), nullable=True), + sa.Column('storage_type', sa.String(50), nullable=False, server_default='binary_storage'), + sa.Column('conversation_id', sa.String(255), nullable=True), + sa.Column('run_id', sa.String(255), nullable=True), + sa.Column('runner_id', sa.String(255), nullable=True), + sa.Column('bot_id', sa.String(255), nullable=True), + sa.Column('workspace_id', sa.String(255), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')), + sa.Column('expires_at', sa.DateTime(), nullable=True), + sa.Column('metadata_json', sa.Text(), nullable=True), + ) + + # Create indexes for agent_artifact + with op.batch_alter_table('agent_artifact', schema=None) as batch_op: + batch_op.create_index('ix_agent_artifact_artifact_id', ['artifact_id'], unique=True) + batch_op.create_index('ix_agent_artifact_conversation_id', ['conversation_id'], unique=False) + batch_op.create_index('ix_agent_artifact_run_id', ['run_id'], unique=False) + + +def downgrade() -> None: + # Drop agent_artifact table + with op.batch_alter_table('agent_artifact', schema=None) as batch_op: + batch_op.drop_index('ix_agent_artifact_run_id') + batch_op.drop_index('ix_agent_artifact_conversation_id') + batch_op.drop_index('ix_agent_artifact_artifact_id') + + op.drop_table('agent_artifact') diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index 334218c3..2c6ca1a2 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -100,6 +100,47 @@ def _build_tool_detail(tool: Any, requested_tool_name: str | None = None) -> dic } +def _validate_artifact_access( + session: dict[str, Any], + artifact_metadata: dict[str, Any], + operation: str, +) -> tuple[bool, str | None]: + """Validate artifact access for a run session. + + Authorization rules (evaluated in order, first match wins): + 1. Artifact run_id matches session run_id → ALLOW (created by this run) + 2. Artifact has conversation_id AND matches session conversation_id → ALLOW (same conversation) + 3. Otherwise → DENY + + Note: Artifacts without conversation_id are NOT globally accessible by default. + Without an explicit scope field, we enforce strict access control. + + Args: + session: AgentRunSession dict with run_id, conversation_id, permissions + artifact_metadata: Artifact metadata dict with conversation_id, run_id + operation: Operation name for error messages ('metadata' or 'read') + + Returns: + Tuple of (is_allowed, error_message). If is_allowed is False, error_message contains reason. + """ + artifact_conversation_id = artifact_metadata.get('conversation_id') + artifact_run_id = artifact_metadata.get('run_id') + session_conversation_id = session.get('conversation_id') + session_run_id = session.get('run_id') + + # Rule 1: Created by this run (allows cross-conversation access for self-created artifacts) + if artifact_run_id and artifact_run_id == session_run_id: + return True, None + + # Rule 2: Same conversation (requires artifact to have conversation_id) + if artifact_conversation_id and session_conversation_id: + if artifact_conversation_id == session_conversation_id: + return True, None + + # Rule 3: Deny - no matching authorization rule + return False, f'Artifact {operation} access denied: artifact not in session conversation and not created by this run' + + def _normalize_uuid_list(values: Any) -> list[str]: """Normalize a user/config supplied UUID list while preserving order.""" if not isinstance(values, list): @@ -1542,6 +1583,169 @@ class RuntimeConnectionHandler(handler.Handler): self.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True) return handler.ActionResponse.error(message=f'Event page error: {e}') + # ================= Artifact APIs ================= + + @self.action(PluginToRuntimeAction.ARTIFACT_METADATA) + async def artifact_metadata(data: dict[str, Any]) -> handler.ActionResponse: + """Get artifact metadata. + + Requires run_id authorization. Only allows access to artifacts + in current run's conversation or created by current run. + """ + run_id = data.get('run_id') + artifact_id = data.get('artifact_id') + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + + if not artifact_id: + return handler.ActionResponse.error(message='artifact_id is required') + + # Validate run session + session_registry = get_session_registry() + session = await session_registry.get(run_id) + if not session: + return handler.ActionResponse.error( + message=f'Run session {run_id} not found or expired' + ) + + # Validate caller plugin identity + if caller_plugin_identity: + session_plugin_identity = session.get('plugin_identity') + if session_plugin_identity and caller_plugin_identity != session_plugin_identity: + return handler.ActionResponse.error( + message=f'Plugin identity mismatch for run_id {run_id}' + ) + + # Check artifact permission from session.permissions (from descriptor.permissions) + permissions = session.get('permissions', {}) + artifact_permissions = permissions.get('artifacts', []) + if 'metadata' not in artifact_permissions: + return handler.ActionResponse.error( + message='Artifact metadata access not authorized' + ) + + # Get artifact metadata + from ..agent.runner.artifact_store import ArtifactStore + store = ArtifactStore(self.ap.persistence_mgr.get_db_engine()) + + try: + metadata = await store.get_metadata(artifact_id) + if not metadata: + return handler.ActionResponse.error( + message=f'Artifact {artifact_id} not found' + ) + + # Validate artifact access scope + is_allowed, error_msg = _validate_artifact_access(session, metadata, 'metadata') + if not is_allowed: + return handler.ActionResponse.error(message=error_msg) + + return handler.ActionResponse.success(data=metadata) + except Exception as e: + self.ap.logger.error(f'ARTIFACT_METADATA error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Artifact metadata error: {e}') + + @self.action(PluginToRuntimeAction.ARTIFACT_READ) + async def artifact_read(data: dict[str, Any]) -> handler.ActionResponse: + """Read artifact content. + + Requires run_id authorization. Only allows access to artifacts + in current run's conversation or created by current run. + Supports range reads with offset/limit. + """ + run_id = data.get('run_id') + artifact_id = data.get('artifact_id') + caller_plugin_identity = data.get('caller_plugin_identity') + + if not run_id: + return handler.ActionResponse.error(message='run_id is required') + + if not artifact_id: + return handler.ActionResponse.error(message='artifact_id is required') + + # Validate and parse offset + offset = data.get('offset', 0) + if not isinstance(offset, int): + try: + offset = int(offset) + except (TypeError, ValueError): + return handler.ActionResponse.error(message='offset must be an integer') + if offset < 0: + return handler.ActionResponse.error(message='offset must be >= 0') + + # Validate and parse limit if provided + limit = data.get('limit') + if limit is not None: + if not isinstance(limit, int): + try: + limit = int(limit) + except (TypeError, ValueError): + return handler.ActionResponse.error(message='limit must be an integer') + if limit <= 0: + return handler.ActionResponse.error(message='limit must be > 0') + + # Validate run session + session_registry = get_session_registry() + session = await session_registry.get(run_id) + if not session: + return handler.ActionResponse.error( + message=f'Run session {run_id} not found or expired' + ) + + # Validate caller plugin identity + if caller_plugin_identity: + session_plugin_identity = session.get('plugin_identity') + if session_plugin_identity and caller_plugin_identity != session_plugin_identity: + return handler.ActionResponse.error( + message=f'Plugin identity mismatch for run_id {run_id}' + ) + + # Check artifact permission from session.permissions (from descriptor.permissions) + permissions = session.get('permissions', {}) + artifact_permissions = permissions.get('artifacts', []) + if 'read' not in artifact_permissions: + return handler.ActionResponse.error( + message='Artifact read access not authorized' + ) + + # Get artifact metadata first to validate access + from ..agent.runner.artifact_store import ArtifactStore + store = ArtifactStore(self.ap.persistence_mgr.get_db_engine()) + + try: + metadata = await store.get_metadata(artifact_id) + if not metadata: + return handler.ActionResponse.error( + message=f'Artifact {artifact_id} not found' + ) + + # Validate artifact access scope + is_allowed, error_msg = _validate_artifact_access(session, metadata, 'read') + if not is_allowed: + return handler.ActionResponse.error(message=error_msg) + + # Read artifact content (validates offset/limit internally) + result = await store.read_artifact( + artifact_id=artifact_id, + offset=offset, + limit=limit, + ) + + if not result: + return handler.ActionResponse.error( + message=f'Failed to read artifact {artifact_id}' + ) + + return handler.ActionResponse.success(data=result) + except ValueError as e: + # Offset/limit validation error + return handler.ActionResponse.error(message=str(e)) + except Exception as e: + self.ap.logger.error(f'ARTIFACT_READ error: {e}', exc_info=True) + return handler.ActionResponse.error(message=f'Artifact read error: {e}') + @self.action(CommonAction.PING) async def ping(data: dict[str, Any]) -> handler.ActionResponse: """Ping""" diff --git a/tests/unit_tests/agent/test_artifact_store.py b/tests/unit_tests/agent/test_artifact_store.py new file mode 100644 index 00000000..1b5607f6 --- /dev/null +++ b/tests/unit_tests/agent/test_artifact_store.py @@ -0,0 +1,625 @@ +"""Tests for ArtifactStore and artifact action handlers.""" +from __future__ import annotations + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +import base64 +import datetime +import asyncio + +from langbot.pkg.agent.runner.artifact_store import ArtifactStore +from langbot.pkg.agent.runner.session_registry import ( + AgentRunSessionRegistry, + get_session_registry, +) + + +class TestArtifactStore: + """Test ArtifactStore operations.""" + + def _make_mock_engine(self): + """Create a mock database engine for AsyncSession-based store. + + Note: The new store uses AsyncSession, so we need to mock + the session factory behavior. + """ + from unittest.mock import MagicMock, AsyncMock, patch + from sqlalchemy.ext.asyncio import AsyncEngine + + engine = MagicMock(spec=AsyncEngine) + return engine + + @pytest.mark.asyncio + async def test_register_artifact_generates_id(self): + """Test register_artifact generates ID if not provided.""" + engine = self._make_mock_engine() + store = ArtifactStore(engine) + + # Mock the session factory + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + artifact_id = await store.register_artifact( + artifact_id=None, + artifact_type="image", + source="platform", + ) + + assert artifact_id is not None + assert len(artifact_id) == 36 # UUID format + + @pytest.mark.asyncio + async def test_register_artifact_with_content(self): + """Test register_artifact stores content in BinaryStorage.""" + engine = self._make_mock_engine() + store = ArtifactStore(engine) + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + content = b"test image content" + artifact_id = await store.register_artifact( + artifact_id="art_001", + artifact_type="image", + source="platform", + content=content, + ) + + assert artifact_id == "art_001" + + @pytest.mark.asyncio + async def test_register_artifact_with_storage_key(self): + """Test register_artifact with pre-existing storage_key.""" + engine = self._make_mock_engine() + store = ArtifactStore(engine) + + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + artifact_id = await store.register_artifact( + artifact_id="art_002", + artifact_type="file", + source="runner", + storage_key="existing_key", + storage_type="binary_storage", + size_bytes=1024, + ) + + assert artifact_id == "art_002" + + @pytest.mark.asyncio + async def test_get_metadata_not_found(self): + """Test get_metadata returns None if not found.""" + engine = self._make_mock_engine() + store = ArtifactStore(engine) + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + metadata = await store.get_metadata("nonexistent") + + assert metadata is None + + @pytest.mark.asyncio + async def test_read_artifact_validates_offset(self): + """Test read_artifact rejects negative offset.""" + engine = self._make_mock_engine() + store = ArtifactStore(engine) + + with pytest.raises(ValueError, match="offset must be >= 0"): + await store.read_artifact("art_001", offset=-1) + + @pytest.mark.asyncio + async def test_read_artifact_validates_limit(self): + """Test read_artifact rejects zero or negative limit.""" + engine = self._make_mock_engine() + store = ArtifactStore(engine) + + with pytest.raises(ValueError, match="limit must be > 0"): + await store.read_artifact("art_001", limit=0) + + with pytest.raises(ValueError, match="limit must be > 0"): + await store.read_artifact("art_001", limit=-5) + + @pytest.mark.asyncio + async def test_read_artifact_not_found(self): + """Test read_artifact returns None if not found.""" + engine = self._make_mock_engine() + store = ArtifactStore(engine) + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + result = await store.read_artifact("nonexistent") + assert result is None + + +class TestArtifactAuthorization: + """Test artifact action handler authorization.""" + + @pytest.fixture + def mock_session_registry(self): + """Create a fresh session registry for testing.""" + # Reset global registry + import langbot.pkg.agent.runner.session_registry as reg + reg._global_registry = None + return get_session_registry() + + @pytest.fixture + def mock_handler(self): + """Create a mock handler for testing actions.""" + from langbot_plugin.runtime.io.handler import Handler + + class MockHandler(Handler): + def __init__(self): + self._responses = {} + + async def call_action(self, action, data, timeout=30): + # Simulate error response for missing run_id + if not data.get("run_id"): + return {"ok": False, "message": "run_id is required"} + return {"ok": True, "data": {}} + + return MockHandler() + + @pytest.mark.asyncio + async def test_artifact_metadata_requires_run_id(self, mock_handler): + """Test artifact_metadata requires run_id.""" + result = await mock_handler.call_action( + "artifact_metadata", + {"run_id": None, "artifact_id": "art_001"}, + ) + + assert result.get("ok") is False or "error" in str(result).lower() + + @pytest.mark.asyncio + async def test_artifact_read_requires_run_id(self, mock_handler): + """Test artifact_read requires run_id.""" + result = await mock_handler.call_action( + "artifact_read", + {"run_id": None, "artifact_id": "art_001"}, + ) + + assert result.get("ok") is False or "error" in str(result).lower() + + +class TestArtifactAccessValidation: + """Test _validate_artifact_access authorization rules.""" + + def _call_validate(self, session, metadata, operation="metadata"): + """Helper to call the validation function.""" + from langbot.pkg.plugin.handler import _validate_artifact_access + return _validate_artifact_access(session, metadata, operation) + + def test_global_artifact_denied_by_default(self): + """Artifacts without conversation_id are denied by default (no global access).""" + session = { + "run_id": "run_001", + "conversation_id": "conv_001", + "permissions": {"artifacts": ["metadata", "read"]}, + } + metadata = { + "artifact_id": "art_global", + "conversation_id": None, # No conversation scope + "run_id": None, # Not created by any run + } + + is_allowed, error = self._call_validate(session, metadata) + assert is_allowed is False + assert "denied" in error.lower() + + def test_own_run_artifact_allowed(self): + """Artifacts created by same run are allowed (even cross-conversation).""" + session = { + "run_id": "run_001", + "conversation_id": "conv_001", + "permissions": {"artifacts": ["metadata", "read"]}, + } + metadata = { + "artifact_id": "art_001", + "conversation_id": "conv_other", # Different conversation + "run_id": "run_001", # Same run + } + + is_allowed, error = self._call_validate(session, metadata) + assert is_allowed is True + assert error is None + + def test_same_conversation_allowed(self): + """Artifacts in same conversation are allowed.""" + session = { + "run_id": "run_001", + "conversation_id": "conv_001", + "permissions": {"artifacts": ["metadata", "read"]}, + } + metadata = { + "artifact_id": "art_001", + "conversation_id": "conv_001", # Same as session + "run_id": "run_other", # Different run + } + + is_allowed, error = self._call_validate(session, metadata) + assert is_allowed is True + assert error is None + + def test_different_conversation_and_run_denied(self): + """Artifacts in different conversation and different run are denied.""" + session = { + "run_id": "run_001", + "conversation_id": "conv_001", + "permissions": {"artifacts": ["metadata", "read"]}, + } + metadata = { + "artifact_id": "art_001", + "conversation_id": "conv_other", # Different conversation + "run_id": "run_other", # Different run + } + + is_allowed, error = self._call_validate(session, metadata) + assert is_allowed is False + assert "denied" in error.lower() + + def test_session_without_conversation_denied_for_conversation_artifact(self): + """Session without conversation_id cannot access conversation-scoped artifacts.""" + session = { + "run_id": "run_001", + "conversation_id": None, # No conversation + "permissions": {"artifacts": ["metadata", "read"]}, + } + metadata = { + "artifact_id": "art_001", + "conversation_id": "conv_001", # Has conversation + "run_id": "run_other", # Different run + } + + is_allowed, error = self._call_validate(session, metadata) + assert is_allowed is False + + def test_session_without_conversation_allowed_for_own_artifact(self): + """Session without conversation can access artifacts it created.""" + session = { + "run_id": "run_001", + "conversation_id": None, # No conversation + "permissions": {"artifacts": ["metadata", "read"]}, + } + metadata = { + "artifact_id": "art_001", + "conversation_id": "conv_001", # Has conversation + "run_id": "run_001", # Same run (created by this run) + } + + is_allowed, error = self._call_validate(session, metadata) + assert is_allowed is True + + +class TestContextAccessArtifactAPIs: + """Test ContextAccess reflects artifact API permissions.""" + + @pytest.mark.asyncio + async def test_context_access_has_artifact_apis_when_permitted(self): + """Test ContextAccess shows artifact APIs when permissions allow.""" + # This tests the context builder logic + # When artifact permissions include 'metadata' and 'read', + # available_apis should reflect that + permissions = {"artifacts": ["metadata", "read"]} + + # Check that permissions are properly interpreted + artifact_metadata_enabled = "metadata" in permissions.get("artifacts", []) + artifact_read_enabled = "read" in permissions.get("artifacts", []) + + assert artifact_metadata_enabled is True + assert artifact_read_enabled is True + + @pytest.mark.asyncio + async def test_context_access_no_artifact_apis_without_permission(self): + """Test ContextAccess hides artifact APIs when permissions denied.""" + permissions = {"artifacts": []} + + artifact_metadata_enabled = "metadata" in permissions.get("artifacts", []) + artifact_read_enabled = "read" in permissions.get("artifacts", []) + + assert artifact_metadata_enabled is False + assert artifact_read_enabled is False + + +class TestArtifactMetadataFieldAlignment: + """Test that Host returns metadata compatible with SDK ArtifactMetadata.""" + + def test_row_to_public_dict_excludes_host_only_fields(self): + """_row_to_public_dict should not return Host-only fields.""" + from langbot.pkg.agent.runner.artifact_store import ArtifactStore + from langbot.pkg.entity.persistence.artifact import AgentArtifact + from unittest.mock import MagicMock + + # Create a mock row + mock_row = MagicMock(spec=AgentArtifact) + mock_row.artifact_id = "art_001" + mock_row.artifact_type = "image" + mock_row.mime_type = "image/png" + mock_row.name = "test.png" + mock_row.size_bytes = 1024 + mock_row.sha256 = "abc123" + mock_row.source = "platform" + mock_row.conversation_id = "conv_001" + mock_row.run_id = "run_001" + mock_row.runner_id = "plugin:test/plugin/runner" + mock_row.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0) + mock_row.expires_at = None + mock_row.metadata_json = None + + # These are Host-only fields that should NOT be in output + # (they don't exist in SDK ArtifactMetadata) + mock_row.bot_id = "bot_001" + mock_row.workspace_id = "ws_001" + mock_row.storage_key = "artifact:art_001" + mock_row.storage_type = "binary_storage" + + store = ArtifactStore(MagicMock()) + result = store._row_to_public_dict(mock_row) + + # SDK-compatible fields should be present + assert result["artifact_id"] == "art_001" + assert result["artifact_type"] == "image" + assert result["source"] == "platform" + assert result["conversation_id"] == "conv_001" + assert result["run_id"] == "run_001" + + # Host-only fields should NOT be present + assert "bot_id" not in result + assert "workspace_id" not in result + assert "storage_key" not in result + assert "storage_type" not in result + + +class TestSessionRegistryPermissions: + """Test that session registry stores and retrieves permissions correctly.""" + + @pytest.fixture + def session_registry(self): + """Create a fresh session registry for testing.""" + import langbot.pkg.agent.runner.session_registry as reg + reg._global_registry = None + return get_session_registry() + + @pytest.mark.asyncio + async def test_register_stores_permissions(self, session_registry): + """Test that register() stores permissions from descriptor.""" + await session_registry.register( + run_id="run_001", + runner_id="plugin:author/plugin/runner", + query_id=None, + plugin_identity="author/plugin", + resources={ + "models": [], + "tools": [], + "knowledge_bases": [], + "files": [], + "storage": {"plugin_storage": True, "workspace_storage": False}, + "platform_capabilities": {}, + }, + permissions={ + "artifacts": ["metadata", "read"], + "history": ["page"], + "events": ["get"], + }, + conversation_id="conv_001", + ) + + session = await session_registry.get("run_001") + assert session is not None + assert session["permissions"]["artifacts"] == ["metadata", "read"] + assert session["permissions"]["history"] == ["page"] + assert session["permissions"]["events"] == ["get"] + + @pytest.mark.asyncio + async def test_register_with_empty_permissions(self, session_registry): + """Test that register() handles empty permissions.""" + await session_registry.register( + run_id="run_002", + runner_id="plugin:author/plugin/runner", + query_id=None, + plugin_identity="author/plugin", + resources={ + "models": [], + "tools": [], + "knowledge_bases": [], + "files": [], + "storage": {"plugin_storage": True, "workspace_storage": False}, + "platform_capabilities": {}, + }, + permissions={}, + conversation_id="conv_001", + ) + + session = await session_registry.get("run_002") + assert session is not None + assert session["permissions"] == {} + + +class TestArtifactStoreRealSQLite: + """Test ArtifactStore with real SQLite database.""" + + @pytest.fixture + async def db_engine(self): + """Create an in-memory SQLite database for testing.""" + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy import text + from langbot.pkg.entity.persistence.base import Base + from langbot.pkg.entity.persistence.artifact import AgentArtifact + from langbot.pkg.entity.persistence.bstorage import BinaryStorage + + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + + # Create tables + async with engine.begin() as conn: + # Create tables manually for in-memory DB + await conn.run_sync(Base.metadata.create_all) + + yield engine + + await engine.dispose() + + @pytest.mark.asyncio + async def test_register_get_metadata_round_trip(self, db_engine): + """Test register_artifact -> get_metadata round trip with real DB.""" + store = ArtifactStore(db_engine) + + # Register artifact with content + content = b"test image content for round trip" + artifact_id = await store.register_artifact( + artifact_id="art_real_001", + artifact_type="image", + source="platform", + mime_type="image/png", + name="test.png", + content=content, + conversation_id="conv_001", + run_id="run_001", + ) + + assert artifact_id == "art_real_001" + + # Get metadata + metadata = await store.get_metadata(artifact_id) + assert metadata is not None + assert metadata["artifact_id"] == "art_real_001" + assert metadata["artifact_type"] == "image" + assert metadata["mime_type"] == "image/png" + assert metadata["source"] == "platform" + assert metadata["conversation_id"] == "conv_001" + assert metadata["run_id"] == "run_001" + + # Verify Host-only fields are NOT in public metadata + assert "storage_key" not in metadata + assert "storage_type" not in metadata + assert "bot_id" not in metadata + assert "workspace_id" not in metadata + + @pytest.mark.asyncio + async def test_read_artifact_round_trip(self, db_engine): + """Test register_artifact -> read_artifact round trip with real DB.""" + store = ArtifactStore(db_engine) + + # Register artifact with content + content = b"test file content for read test" + artifact_id = await store.register_artifact( + artifact_id="art_real_002", + artifact_type="file", + source="runner", + mime_type="text/plain", + name="test.txt", + content=content, + conversation_id="conv_001", + run_id="run_001", + ) + + # Read artifact + result = await store.read_artifact(artifact_id) + assert result is not None + assert result["artifact_id"] == "art_real_002" + assert result["mime_type"] == "text/plain" + assert result["offset"] == 0 + assert result["length"] == len(content) + assert result["has_more"] is False + + # Verify content + decoded_content = base64.b64decode(result["content_base64"]) + assert decoded_content == content + + @pytest.mark.asyncio + async def test_read_artifact_with_offset_limit(self, db_engine): + """Test read_artifact with offset and limit.""" + store = ArtifactStore(db_engine) + + # Register artifact with content + content = b"0123456789" * 100 # 1000 bytes + artifact_id = await store.register_artifact( + artifact_id="art_real_003", + artifact_type="file", + source="runner", + mime_type="application/octet-stream", + content=content, + ) + + # Read with offset + result = await store.read_artifact(artifact_id, offset=100, limit=100) + assert result is not None + assert result["offset"] == 100 + assert result["length"] == 100 + + # Verify content + decoded_content = base64.b64decode(result["content_base64"]) + assert decoded_content == content[100:200] + + @pytest.mark.asyncio + async def test_read_artifact_has_more(self, db_engine): + """Test read_artifact sets has_more correctly.""" + store = ArtifactStore(db_engine) + + # Register artifact with content + content = b"0123456789" * 100 # 1000 bytes + artifact_id = await store.register_artifact( + artifact_id="art_real_004", + artifact_type="file", + source="runner", + content=content, + ) + + # Read with limit smaller than content + result = await store.read_artifact(artifact_id, offset=0, limit=100) + assert result is not None + assert result["has_more"] is True + assert result["length"] == 100 + + @pytest.mark.asyncio + async def test_metadata_sdk_validation(self, db_engine): + """Test that metadata can be validated by SDK ArtifactMetadata.""" + from langbot_plugin.api.entities.builtin.agent_runner.artifact import ArtifactMetadata + + store = ArtifactStore(db_engine) + + # Register artifact + artifact_id = await store.register_artifact( + artifact_id="art_real_005", + artifact_type="file", + source="runner", + mime_type="application/pdf", + name="document.pdf", + size_bytes=1024, + conversation_id="conv_001", + run_id="run_001", + runner_id="plugin:test/plugin/runner", + ) + + # Get metadata + metadata = await store.get_metadata(artifact_id) + assert metadata is not None + + # Should not raise ValidationError + validated = ArtifactMetadata.model_validate(metadata) + assert validated.artifact_id == "art_real_005" + assert validated.artifact_type == "file" diff --git a/tests/unit_tests/agent/test_event_log_transcript.py b/tests/unit_tests/agent/test_event_log_transcript.py index 4b1d5fb1..e6fb8d6a 100644 --- a/tests/unit_tests/agent/test_event_log_transcript.py +++ b/tests/unit_tests/agent/test_event_log_transcript.py @@ -73,49 +73,78 @@ class TestEventLogStore: @pytest.mark.asyncio async def test_append_event(self, mock_db_engine): """Test appending an event to EventLog.""" + from unittest.mock import AsyncMock, MagicMock, patch + store = EventLogStore(mock_db_engine) - event_id = await store.append_event( - event_id="evt_1", - event_type="message.received", - source="platform", - bot_id="bot_1", - conversation_id="conv_1", - actor_type="user", - actor_id="user_1", - input_summary="Hello world", - run_id="run_1", - runner_id="plugin:test/plugin/runner", - ) + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() - assert event_id == "evt_1" + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + event_id = await store.append_event( + event_id="evt_1", + event_type="message.received", + source="platform", + bot_id="bot_1", + conversation_id="conv_1", + actor_type="user", + actor_id="user_1", + input_summary="Hello world", + run_id="run_1", + runner_id="plugin:test/plugin/runner", + ) + + assert event_id == "evt_1" @pytest.mark.asyncio async def test_append_event_truncates_input_summary(self, mock_db_engine): """Test that long input summaries are truncated.""" + from unittest.mock import AsyncMock, MagicMock, patch + store = EventLogStore(mock_db_engine) - long_text = "x" * 2000 - event_id = await store.append_event( - event_id="evt_2", - event_type="message.received", - source="platform", - input_summary=long_text, - ) + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() - assert event_id == "evt_2" + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + long_text = "x" * 2000 + event_id = await store.append_event( + event_id="evt_2", + event_type="message.received", + source="platform", + input_summary=long_text, + ) + + assert event_id == "evt_2" @pytest.mark.asyncio async def test_page_events_with_conversation_filter(self, mock_db_engine): """Test paging events with conversation_id filter.""" + from unittest.mock import AsyncMock, MagicMock, patch + store = EventLogStore(mock_db_engine) - items, next_seq, has_more = await store.page_events( - conversation_id="conv_1", - limit=10, - ) + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] - assert isinstance(items, list) + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + items, next_seq, has_more = await store.page_events( + conversation_id="conv_1", + limit=10, + ) + + assert isinstance(items, list) class TestTranscriptStore: @@ -124,75 +153,129 @@ class TestTranscriptStore: @pytest.mark.asyncio async def test_append_transcript(self, mock_db_engine): """Test appending a transcript item.""" + from unittest.mock import AsyncMock, MagicMock, patch + store = TranscriptStore(mock_db_engine) - transcript_id = await store.append_transcript( - transcript_id=None, # Auto-generate - event_id="evt_1", - conversation_id="conv_1", - role="user", - content="Hello", - ) + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() - assert transcript_id is not None + # Mock _get_next_seq + with patch.object(store, '_get_next_seq', return_value=1): + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + transcript_id = await store.append_transcript( + transcript_id=None, # Auto-generate + event_id="evt_1", + conversation_id="conv_1", + role="user", + content="Hello", + ) + + assert transcript_id is not None @pytest.mark.asyncio async def test_append_transcript_with_artifacts(self, mock_db_engine): """Test appending transcript with artifact refs.""" + from unittest.mock import AsyncMock, MagicMock, patch + store = TranscriptStore(mock_db_engine) - transcript_id = await store.append_transcript( - transcript_id=None, # Auto-generate - event_id="evt_2", - conversation_id="conv_1", - role="assistant", - content="Here's an image", - artifact_refs=[ - {"artifact_id": "art_1", "artifact_type": "image", "url": "http://example.com/img.png"} - ], - ) + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() - assert transcript_id is not None + with patch.object(store, '_get_next_seq', return_value=1): + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + transcript_id = await store.append_transcript( + transcript_id=None, # Auto-generate + event_id="evt_2", + conversation_id="conv_1", + role="assistant", + content="Here's an image", + artifact_refs=[ + {"artifact_id": "art_1", "artifact_type": "image", "url": "http://example.com/img.png"} + ], + ) + + assert transcript_id is not None @pytest.mark.asyncio async def test_page_transcript_backward(self, mock_db_engine): """Test paging transcript backward (older items).""" + from unittest.mock import AsyncMock, MagicMock, patch + store = TranscriptStore(mock_db_engine) - items, next_seq, prev_seq, has_more = await store.page_transcript( - conversation_id="conv_1", - limit=10, - direction="backward", - ) + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] - assert isinstance(items, list) + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + items, next_seq, prev_seq, has_more = await store.page_transcript( + conversation_id="conv_1", + limit=10, + direction="backward", + ) + + assert isinstance(items, list) @pytest.mark.asyncio async def test_page_transcript_has_hard_limit(self, mock_db_engine): """Test that transcript paging has a hard limit.""" + from unittest.mock import AsyncMock, MagicMock, patch + store = TranscriptStore(mock_db_engine) - # Request more than the hard limit - items, next_seq, prev_seq, has_more = await store.page_transcript( - conversation_id="conv_1", - limit=200, # Request 200, but hard limit is 100 - ) + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] - # The store should cap at 100 - assert len(items) <= store.HARD_LIMIT + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + # Request more than the hard limit + items, next_seq, prev_seq, has_more = await store.page_transcript( + conversation_id="conv_1", + limit=200, # Request 200, but hard limit is 100 + ) + + # The store should cap at 100 + assert len(items) <= store.HARD_LIMIT @pytest.mark.asyncio async def test_search_transcript(self, mock_db_engine): """Test searching transcript.""" + from unittest.mock import AsyncMock, MagicMock, patch + store = TranscriptStore(mock_db_engine) - items = await store.search_transcript( - conversation_id="conv_1", - query_text="database", - top_k=10, - ) + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] - assert isinstance(items, list) + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + items = await store.search_transcript( + conversation_id="conv_1", + query_text="database", + top_k=10, + ) + + assert isinstance(items, list) class TestHistoryPageAuthorization: @@ -259,50 +342,244 @@ class TestContextAccessPopulation: @pytest.mark.asyncio async def test_context_access_has_history_apis_when_permitted(self, mock_db_engine): """Test ContextAccess shows available APIs based on permissions.""" - # This would test the context builder logic - # For now we verify the store methods work + from unittest.mock import AsyncMock, MagicMock, patch + store = TranscriptStore(mock_db_engine) - cursor = await store.get_latest_cursor("conv_1") - # Should return None or a cursor string - assert cursor is None or isinstance(cursor, str) + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + cursor = await store.get_latest_cursor("conv_1") + # Should return None or a cursor string + assert cursor is None or isinstance(cursor, str) @pytest.mark.asyncio async def test_context_access_shows_has_history_before(self, mock_db_engine): """Test ContextAccess indicates if history exists.""" + from unittest.mock import AsyncMock, MagicMock, patch + store = TranscriptStore(mock_db_engine) - has_history = await store.has_history_before("conv_1", 10) - assert isinstance(has_history, bool) + mock_result = MagicMock() + mock_result.scalar.return_value = 0 + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch.object(store, '_session_factory') as mock_factory: + mock_factory.return_value.__aenter__.return_value = mock_session + + has_history = await store.has_history_before("conv_1", 10) + assert isinstance(has_history, bool) + + +class TestEventLogStoreRealSQLite: + """Test EventLogStore with real SQLite database.""" + + @pytest.fixture + async def db_engine(self): + """Create an in-memory SQLite database for testing.""" + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy import text + from langbot.pkg.entity.persistence.base import Base + from langbot.pkg.entity.persistence.event_log import EventLog + + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + + # Create tables + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + await engine.dispose() + + @pytest.mark.asyncio + async def test_append_get_event_round_trip(self, db_engine): + """Test append_event -> get_event round trip with real DB.""" + store = EventLogStore(db_engine) + + # Append event + event_id = await store.append_event( + event_id="evt_real_001", + event_type="message.received", + source="platform", + bot_id="bot_001", + conversation_id="conv_001", + actor_type="user", + actor_id="user_001", + actor_name="Test User", + input_summary="Hello world", + run_id="run_001", + runner_id="plugin:test/plugin/runner", + ) + + assert event_id == "evt_real_001" + + # Get event + event = await store.get_event(event_id) + assert event is not None + assert event["event_id"] == "evt_real_001" + assert event["event_type"] == "message.received" + assert event["source"] == "platform" + assert event["conversation_id"] == "conv_001" + assert event["actor_type"] == "user" + assert event["actor_id"] == "user_001" + + @pytest.mark.asyncio + async def test_page_events(self, db_engine): + """Test page_events with real DB.""" + store = EventLogStore(db_engine) + + # Append multiple events + for i in range(5): + await store.append_event( + event_id=f"evt_real_{i:03d}", + event_type="message.received", + source="platform", + conversation_id="conv_001", + input_summary=f"Message {i}", + ) + + # Page events + items, next_seq, has_more = await store.page_events( + conversation_id="conv_001", + limit=3, + ) + + assert len(items) == 3 + assert has_more is True + + @pytest.mark.asyncio + async def test_get_latest_cursor(self, db_engine): + """Test get_latest_cursor with real DB.""" + store = EventLogStore(db_engine) + + # Append events + for i in range(3): + await store.append_event( + event_id=f"evt_cursor_{i:03d}", + event_type="message.received", + source="platform", + conversation_id="conv_cursor", + ) + + # Get latest cursor + cursor = await store.get_latest_cursor("conv_cursor") + assert cursor is not None + assert int(cursor) > 0 + + +class TestTranscriptStoreRealSQLite: + """Test TranscriptStore with real SQLite database.""" + + @pytest.fixture + async def db_engine(self): + """Create an in-memory SQLite database for testing.""" + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy import text + from langbot.pkg.entity.persistence.base import Base + from langbot.pkg.entity.persistence.transcript import Transcript + + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + + # Create tables + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + await engine.dispose() + + @pytest.mark.asyncio + async def test_append_page_transcript_round_trip(self, db_engine): + """Test append_transcript -> page_transcript round trip with real DB.""" + store = TranscriptStore(db_engine) + + # Append transcript items + for i in range(3): + await store.append_transcript( + transcript_id=f"trans_real_{i:03d}", + event_id=f"evt_{i:03d}", + conversation_id="conv_001", + role="user" if i % 2 == 0 else "assistant", + content=f"Message {i}", + ) + + # Page transcript + items, next_seq, prev_seq, has_more = await store.page_transcript( + conversation_id="conv_001", + limit=10, + ) + + assert len(items) == 3 + assert items[0]["conversation_id"] == "conv_001" + + @pytest.mark.asyncio + async def test_search_transcript_real_db(self, db_engine): + """Test search_transcript with real DB.""" + store = TranscriptStore(db_engine) + + # Append transcript items + await store.append_transcript( + transcript_id="trans_search_001", + event_id="evt_search_001", + conversation_id="conv_search", + role="user", + content="I want to learn about databases", + ) + await store.append_transcript( + transcript_id="trans_search_002", + event_id="evt_search_002", + conversation_id="conv_search", + role="assistant", + content="Here is information about databases", + ) + + # Search for "database" + items = await store.search_transcript( + conversation_id="conv_search", + query_text="database", + ) + + # Should find at least one match + assert len(items) >= 1 + + @pytest.mark.asyncio + async def test_get_latest_cursor_real_db(self, db_engine): + """Test get_latest_cursor with real DB.""" + store = TranscriptStore(db_engine) + + # Append transcript items + for i in range(3): + await store.append_transcript( + transcript_id=f"trans_cursor_{i:03d}", + event_id=f"evt_cursor_{i:03d}", + conversation_id="conv_cursor", + role="user", + content=f"Message {i}", + ) + + # Get latest cursor + cursor = await store.get_latest_cursor("conv_cursor") + assert cursor is not None + assert int(cursor) > 0 # Fixtures @pytest.fixture def mock_db_engine(): - """Create a mock database engine.""" - from unittest.mock import MagicMock, AsyncMock + """Create a mock database engine for AsyncSession-based stores.""" + from unittest.mock import MagicMock from sqlalchemy.ext.asyncio import AsyncEngine engine = MagicMock(spec=AsyncEngine) - - # Mock connection - mock_conn = MagicMock() - mock_result = MagicMock() - mock_result.fetchone.return_value = None - mock_result.fetchall.return_value = [] - mock_result.scalar.return_value = 0 - mock_conn.execute = AsyncMock(return_value=mock_result) - mock_conn.commit = AsyncMock() - - # Create async context manager for connect() - class AsyncConnectContextManager: - async def __aenter__(self): - return mock_conn - async def __aexit__(self, *args): - pass - - # connect() should return an async context manager - engine.connect = MagicMock(return_value=AsyncConnectContextManager()) return engine