feat(agent-runner): add artifact store pull APIs

This commit is contained in:
huanghuoguoguo
2026-05-23 17:29:18 +08:00
parent 8db23bf950
commit e0e321251e
12 changed files with 1728 additions and 170 deletions

View File

@@ -0,0 +1,300 @@
"""Artifact store for managing Host-owned artifacts."""
from __future__ import annotations
import json
import datetime
import typing
import uuid
import base64
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import sessionmaker
from ...entity.persistence.artifact import AgentArtifact
from ...entity.persistence.bstorage import BinaryStorage
class ArtifactStore:
    """Store for AgentArtifact records.

    Registers artifact metadata rows and serves metadata and content reads.
    Content handed to ``register_artifact`` is written to a BinaryStorage row;
    artifacts with any other storage type are only referenced through their
    ``storage_key``.

    All methods are async and open a short-lived ``AsyncSession`` from the
    engine passed to the constructor.
    """

    engine: AsyncEngine

    # Hard limits
    MAX_INLINE_READ_BYTES = 1024 * 1024  # 1MB: read size used when no limit is given
    MAX_RANGE_READ_BYTES = 10 * 1024 * 1024  # 10MB: ceiling applied to any requested limit

    def __init__(self, engine: AsyncEngine):
        self.engine = engine
        # expire_on_commit=False keeps loaded attributes readable after the
        # session that produced a row has been committed and closed.
        self._session_factory = sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )

    async def register_artifact(
        self,
        artifact_id: str | None,
        artifact_type: str,
        source: str,
        storage_key: str | None = None,
        storage_type: str = 'binary_storage',
        mime_type: str | None = None,
        name: str | None = None,
        size_bytes: int | None = None,
        sha256: str | None = None,
        conversation_id: str | None = None,
        run_id: str | None = None,
        runner_id: str | None = None,
        bot_id: str | None = None,
        workspace_id: str | None = None,
        expires_at: datetime.datetime | None = None,
        metadata: dict[str, typing.Any] | None = None,
        content: bytes | None = None,
    ) -> str:
        """Register a new artifact.

        When ``content`` is provided it is stored in BinaryStorage in the same
        transaction as the metadata row. If no ``storage_key`` was given, one
        is derived from the artifact ID and ``storage_type`` is forced to
        ``'binary_storage'``.

        Args:
            artifact_id: Unique artifact ID (generated if None)
            artifact_type: Type of artifact (image, file, voice, tool_result, etc.)
            source: Source of artifact (platform, runner, tool, system)
            storage_key: Key in BinaryStorage or external reference
            storage_type: Storage type (binary_storage, file, url)
            mime_type: MIME type
            name: Original file name
            size_bytes: Size in bytes (derived from ``content`` when omitted)
            sha256: SHA256 hash
            conversation_id: Conversation ID
            run_id: Run ID that created this
            runner_id: Runner ID that created this
            bot_id: Bot UUID
            workspace_id: Workspace ID
            expires_at: Expiration time
            metadata: Additional metadata
            content: Optional content to store in BinaryStorage

        Returns:
            The artifact_id
        """
        if artifact_id is None:
            artifact_id = str(uuid.uuid4())

        if content is not None:
            if storage_key is None:
                storage_key = f"artifact:{artifact_id}"
                storage_type = 'binary_storage'
            if size_bytes is None:
                size_bytes = len(content)

        # Naive UTC, same value utcnow() produced, without the deprecated call.
        created_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        async with self._session_factory() as session:
            if content is not None:
                # The blob is stored under the same key that read_artifact()
                # later passes to _read_binary_storage(), so content registered
                # with an explicit storage_key can be read back.
                session.add(
                    BinaryStorage(
                        unique_key=storage_key,
                        key=storage_key,
                        owner_type='artifact',
                        owner='host',
                        value=content,
                    )
                )

            session.add(
                AgentArtifact(
                    artifact_id=artifact_id,
                    artifact_type=artifact_type,
                    mime_type=mime_type,
                    name=name,
                    size_bytes=size_bytes,
                    sha256=sha256,
                    source=source,
                    storage_key=storage_key,
                    storage_type=storage_type,
                    conversation_id=conversation_id,
                    run_id=run_id,
                    runner_id=runner_id,
                    bot_id=bot_id,
                    workspace_id=workspace_id,
                    created_at=created_at,
                    expires_at=expires_at,
                    metadata_json=json.dumps(metadata) if metadata else None,
                )
            )
            await session.commit()

        return artifact_id

    async def get_metadata(
        self,
        artifact_id: str,
    ) -> dict[str, typing.Any] | None:
        """Get artifact metadata (public fields only, no internal storage info).

        Args:
            artifact_id: Artifact ID

        Returns:
            Public metadata dict as built by ``_row_to_public_dict``, or None
            if no artifact has this ID
        """
        row = await self._get_internal_record(artifact_id)
        if row is None:
            return None
        return self._row_to_public_dict(row)

    async def _get_internal_record(
        self,
        artifact_id: str,
    ) -> AgentArtifact | None:
        """Get the full artifact record including internal fields.

        Used by read_artifact to access storage_key/storage_type.

        Args:
            artifact_id: Artifact ID

        Returns:
            AgentArtifact ORM instance, or None if not found
        """
        async with self._session_factory() as session:
            result = await session.execute(
                sqlalchemy.select(AgentArtifact).where(
                    AgentArtifact.artifact_id == artifact_id
                )
            )
            return result.scalars().first()

    async def read_artifact(
        self,
        artifact_id: str,
        offset: int = 0,
        limit: int | None = None,
    ) -> dict[str, typing.Any] | None:
        """Read artifact content.

        Artifacts held in BinaryStorage are returned inline as base64, sliced
        by ``offset``/``limit``. For every other storage type only the storage
        reference is returned in ``file_key``.

        Args:
            artifact_id: Artifact ID
            offset: Byte offset to start reading from (must be >= 0)
            limit: Maximum bytes to read (must be > 0 if provided). Defaults to
                MAX_INLINE_READ_BYTES and is capped at MAX_RANGE_READ_BYTES.

        Returns:
            Read-result dict, or None if the artifact or its stored content
            is not found

        Raises:
            ValueError: If offset < 0 or limit <= 0
        """
        if offset < 0:
            raise ValueError("offset must be >= 0")
        if limit is not None and limit <= 0:
            raise ValueError("limit must be > 0")

        record = await self._get_internal_record(artifact_id)
        if record is None:
            return None

        storage_type = record.storage_type or 'binary_storage'
        storage_key = record.storage_key

        effective_limit = min(
            self.MAX_INLINE_READ_BYTES if limit is None else limit,
            self.MAX_RANGE_READ_BYTES,
        )

        if storage_type == 'binary_storage' and storage_key:
            blob = await self._read_binary_storage(storage_key)
            if blob is None:
                return None

            # Fall back to the real blob length when no size was recorded,
            # instead of reporting 0 next to a non-empty payload.
            total_size = record.size_bytes if record.size_bytes is not None else len(blob)
            window = blob[offset:offset + effective_limit]

            return {
                'artifact_id': artifact_id,
                'mime_type': record.mime_type,
                'size_bytes': total_size,
                'offset': offset,
                'length': len(window),
                'content_base64': base64.b64encode(window).decode('utf-8'),
                'file_key': None,
                # More data remains only if the window stopped short of the end.
                'has_more': offset + len(window) < len(blob),
            }

        # Other storage types: hand back the storage reference only.
        return {
            'artifact_id': artifact_id,
            'mime_type': record.mime_type,
            'size_bytes': record.size_bytes or 0,
            'offset': offset,
            'length': None,
            'content_base64': None,
            'file_key': storage_key,
            'has_more': False,
        }

    async def _read_binary_storage(self, key: str) -> bytes | None:
        """Read content from BinaryStorage.

        Args:
            key: Value matched against ``BinaryStorage.unique_key``

        Returns:
            Content bytes, or None if not found
        """
        async with self._session_factory() as session:
            result = await session.execute(
                sqlalchemy.select(BinaryStorage).where(BinaryStorage.unique_key == key)
            )
            row = result.scalars().first()
            if row is None:
                return None
            return row.value

    @staticmethod
    def _to_epoch_seconds(value: datetime.datetime | None) -> int | None:
        """Convert a stored datetime to whole epoch seconds.

        Naive values are treated as UTC, because ``created_at`` is written as
        naive UTC. Calling ``timestamp()`` on a naive datetime would interpret
        it as local time and shift the result by the host's UTC offset.
        """
        if value is None:
            return None
        if value.tzinfo is None:
            value = value.replace(tzinfo=datetime.timezone.utc)
        return int(value.timestamp())

    def _row_to_public_dict(self, row: AgentArtifact) -> dict[str, typing.Any]:
        """Convert an AgentArtifact row to its public dict form.

        Host-only fields (bot_id, workspace_id, storage_key, storage_type) are
        left out. Timestamps are epoch seconds; ``metadata`` is the decoded
        ``metadata_json`` or an empty dict.
        """
        return {
            'artifact_id': row.artifact_id,
            'artifact_type': row.artifact_type,
            'mime_type': row.mime_type,
            'name': row.name,
            'size_bytes': row.size_bytes,
            'sha256': row.sha256,
            'source': row.source,
            'conversation_id': row.conversation_id,
            'run_id': row.run_id,
            'runner_id': row.runner_id,
            'created_at': self._to_epoch_seconds(row.created_at),
            'expires_at': self._to_epoch_seconds(row.expires_at),
            'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
        }

View File

@@ -891,11 +891,14 @@ class AgentRunContextBuilder:
permissions = descriptor.permissions or {}
history_permissions = permissions.get('history', [])
event_permissions = permissions.get('events', [])
artifact_permissions = permissions.get('artifacts', [])
history_page_enabled = 'page' in history_permissions and conversation_id is not None
history_search_enabled = 'search' in history_permissions and conversation_id is not None
event_get_enabled = 'get' in event_permissions
event_page_enabled = 'page' in event_permissions and conversation_id is not None
artifact_metadata_enabled = 'metadata' in artifact_permissions
artifact_read_enabled = 'read' in artifact_permissions
# Get latest cursor and has_history_before if conversation exists
latest_cursor = None
@@ -931,8 +934,8 @@ class AgentRunContextBuilder:
'history_search': history_search_enabled,
'event_get': event_get_enabled,
'event_page': event_page_enabled,
'artifact_metadata': False, # TODO: Implement artifact store
'artifact_read': False,
'artifact_metadata': artifact_metadata_enabled,
'artifact_read': artifact_read_enabled,
'state': True,
'storage': True,
},

View File

@@ -7,7 +7,8 @@ import typing
import uuid
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import sessionmaker
from ...entity.persistence.event_log import EventLog
from ...entity.persistence.transcript import Transcript
@@ -27,6 +28,9 @@ class EventLogStore:
def __init__(self, engine: AsyncEngine):
self.engine = engine
self._session_factory = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
async def append_event(
self,
@@ -83,32 +87,31 @@ class EventLogStore:
if input_summary and len(input_summary) > self.MAX_INPUT_SUMMARY_LENGTH:
input_summary = input_summary[:self.MAX_INPUT_SUMMARY_LENGTH - 3] + "..."
async with self.engine.connect() as conn:
await conn.execute(
sqlalchemy.insert(EventLog).values(
event_id=event_id,
event_type=event_type,
event_time=event_time,
source=source,
bot_id=bot_id,
workspace_id=workspace_id,
conversation_id=conversation_id,
thread_id=thread_id,
actor_type=actor_type,
actor_id=actor_id,
actor_name=actor_name,
subject_type=subject_type,
subject_id=subject_id,
input_summary=input_summary,
input_json=json.dumps(input_json) if input_json else None,
raw_ref=raw_ref,
run_id=run_id,
runner_id=runner_id,
metadata_json=json.dumps(metadata) if metadata else None,
created_at=datetime.datetime.utcnow(),
)
async with self._session_factory() as session:
event = EventLog(
event_id=event_id,
event_type=event_type,
event_time=event_time,
source=source,
bot_id=bot_id,
workspace_id=workspace_id,
conversation_id=conversation_id,
thread_id=thread_id,
actor_type=actor_type,
actor_id=actor_id,
actor_name=actor_name,
subject_type=subject_type,
subject_id=subject_id,
input_summary=input_summary,
input_json=json.dumps(input_json) if input_json else None,
raw_ref=raw_ref,
run_id=run_id,
runner_id=runner_id,
metadata_json=json.dumps(metadata) if metadata else None,
created_at=datetime.datetime.utcnow(),
)
await conn.commit()
session.add(event)
await session.commit()
return event_id
@@ -124,14 +127,14 @@ class EventLogStore:
Returns:
Event record as dict, or None if not found
"""
async with self.engine.connect() as conn:
result = await conn.execute(
async with self._session_factory() as session:
result = await session.execute(
sqlalchemy.select(EventLog).where(EventLog.event_id == event_id)
)
row = result.fetchone()
row = result.scalars().first()
if row is None:
return None
return self._row_to_dict(row[0])
return self._row_to_dict(row)
async def page_events(
self,
@@ -153,7 +156,7 @@ class EventLogStore:
"""
limit = min(limit, 100) # Hard cap
async with self.engine.connect() as conn:
async with self._session_factory() as session:
query = sqlalchemy.select(EventLog)
if conversation_id is not None:
@@ -167,10 +170,10 @@ class EventLogStore:
query = query.order_by(EventLog.id.desc()).limit(limit + 1)
result = await conn.execute(query)
rows = result.fetchall()
result = await session.execute(query)
rows = result.scalars().all()
items = [self._row_to_dict(row[0]) for row in rows[:limit]]
items = [self._row_to_dict(row) for row in rows[:limit]]
has_more = len(rows) > limit
next_seq = items[-1]['id'] if items and has_more else None
@@ -188,17 +191,17 @@ class EventLogStore:
Returns:
Cursor string (seq number), or None if no events
"""
async with self.engine.connect() as conn:
result = await conn.execute(
async with self._session_factory() as session:
result = await session.execute(
sqlalchemy.select(EventLog.id)
.where(EventLog.conversation_id == conversation_id)
.order_by(EventLog.id.desc())
.limit(1)
)
row = result.fetchone()
row = result.scalars().first()
if row is None:
return None
return str(row[0])
return str(row)
async def has_events_before(
self,
@@ -214,8 +217,8 @@ class EventLogStore:
Returns:
True if there are events before
"""
async with self.engine.connect() as conn:
result = await conn.execute(
async with self._session_factory() as session:
result = await session.execute(
sqlalchemy.select(sqlalchemy.func.count())
.select_from(EventLog)
.where(

View File

@@ -125,6 +125,7 @@ class AgentRunOrchestrator:
query_id=None, # No query_id in event-first mode
plugin_identity=descriptor.get_plugin_id(),
resources=resources,
permissions=descriptor.permissions or {},
conversation_id=event.conversation_id,
)
@@ -222,6 +223,7 @@ class AgentRunOrchestrator:
query_id=query.query_id,
plugin_identity=descriptor.get_plugin_id(),
resources=resources,
permissions=descriptor.permissions or {},
conversation_id=conversation_id,
)

View File

@@ -27,6 +27,7 @@ class AgentRunSession(typing.TypedDict):
plugin_identity: Plugin identifier (author/name) of the runner
conversation_id: Conversation ID for history/event access
resources: Authorized resources for this run (from AgentResources)
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
status: Session status tracking
_authorized_ids: Pre-computed authorized resource IDs for O(1) lookup
"""
@@ -36,6 +37,7 @@ class AgentRunSession(typing.TypedDict):
plugin_identity: str # author/name
conversation_id: str | None
resources: AgentResources
permissions: dict[str, list[str]]
status: AgentRunSessionStatus
_authorized_ids: dict[str, set[str]] # Pre-computed sets for O(1) lookup
@@ -67,6 +69,7 @@ class AgentRunSessionRegistry:
plugin_identity: str,
resources: AgentResources,
conversation_id: str | None = None,
permissions: dict[str, list[str]] | None = None,
) -> None:
"""Register a new agent run session.
@@ -77,9 +80,13 @@ class AgentRunSessionRegistry:
plugin_identity: Plugin identifier (author/name)
resources: Authorized resources for this run
conversation_id: Conversation ID for history/event access
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
"""
now = int(time.time())
# Normalize permissions to empty dict if None
permissions = permissions or {}
# Pre-compute authorized resource IDs for O(1) lookup
authorized_ids: dict[str, set[str]] = {
'model': {m.get('model_id') for m in resources.get('models', [])},
@@ -95,6 +102,7 @@ class AgentRunSessionRegistry:
'plugin_identity': plugin_identity,
'conversation_id': conversation_id,
'resources': resources,
'permissions': permissions,
'status': {
'started_at': now,
'last_activity_at': now,

View File

@@ -7,7 +7,8 @@ import typing
import uuid
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import sessionmaker
from ...entity.persistence.transcript import Transcript
@@ -27,6 +28,9 @@ class TranscriptStore:
def __init__(self, engine: AsyncEngine):
self.engine = engine
self._session_factory = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
async def append_transcript(
self,
@@ -72,26 +76,25 @@ class TranscriptStore:
# Get next sequence number for this conversation
seq = await self._get_next_seq(conversation_id)
async with self.engine.connect() as conn:
await conn.execute(
sqlalchemy.insert(Transcript).values(
transcript_id=transcript_id,
event_id=event_id,
conversation_id=conversation_id,
thread_id=thread_id,
role=role,
item_type=item_type,
content=content,
content_json=json.dumps(content_json) if content_json else None,
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
seq=seq,
run_id=run_id,
runner_id=runner_id,
created_at=datetime.datetime.utcnow(),
metadata_json=json.dumps(metadata) if metadata else None,
)
async with self._session_factory() as session:
item = Transcript(
transcript_id=transcript_id,
event_id=event_id,
conversation_id=conversation_id,
thread_id=thread_id,
role=role,
item_type=item_type,
content=content,
content_json=json.dumps(content_json) if content_json else None,
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
seq=seq,
run_id=run_id,
runner_id=runner_id,
created_at=datetime.datetime.utcnow(),
metadata_json=json.dumps(metadata) if metadata else None,
)
await conn.commit()
session.add(item)
await session.commit()
return transcript_id
@@ -119,7 +122,7 @@ class TranscriptStore:
"""
limit = min(limit, self.HARD_LIMIT)
async with self.engine.connect() as conn:
async with self._session_factory() as session:
query = sqlalchemy.select(Transcript).where(
Transcript.conversation_id == conversation_id
)
@@ -136,10 +139,10 @@ class TranscriptStore:
query = query.limit(limit + 1)
result = await conn.execute(query)
rows = result.fetchall()
result = await session.execute(query)
rows = result.scalars().all()
items = [self._row_to_dict(row[0], include_artifacts) for row in rows[:limit]]
items = [self._row_to_dict(row, include_artifacts) for row in rows[:limit]]
has_more = len(rows) > limit
# Calculate cursors
@@ -179,7 +182,7 @@ class TranscriptStore:
Returns:
List of matching items
"""
async with self.engine.connect() as conn:
async with self._session_factory() as session:
query = sqlalchemy.select(Transcript).where(
Transcript.conversation_id == conversation_id,
Transcript.content.ilike(f"%{query_text}%"),
@@ -194,10 +197,10 @@ class TranscriptStore:
query = query.order_by(Transcript.seq.desc()).limit(top_k)
result = await conn.execute(query)
rows = result.fetchall()
result = await session.execute(query)
rows = result.scalars().all()
return [self._row_to_dict(row[0], include_artifacts=True) for row in rows]
return [self._row_to_dict(row, include_artifacts=True) for row in rows]
async def get_latest_cursor(
self,
@@ -211,17 +214,17 @@ class TranscriptStore:
Returns:
Cursor string (seq number), or None if no items
"""
async with self.engine.connect() as conn:
result = await conn.execute(
async with self._session_factory() as session:
result = await session.execute(
sqlalchemy.select(Transcript.seq)
.where(Transcript.conversation_id == conversation_id)
.order_by(Transcript.seq.desc())
.limit(1)
)
row = result.fetchone()
row = result.scalars().first()
if row is None:
return None
return str(row[0])
return str(row)
async def has_history_before(
self,
@@ -237,8 +240,8 @@ class TranscriptStore:
Returns:
True if there are items before
"""
async with self.engine.connect() as conn:
result = await conn.execute(
async with self._session_factory() as session:
result = await session.execute(
sqlalchemy.select(sqlalchemy.func.count())
.select_from(Transcript)
.where(
@@ -251,8 +254,8 @@ class TranscriptStore:
async def _get_next_seq(self, conversation_id: str) -> int:
"""Get the next sequence number for a conversation."""
async with self.engine.connect() as conn:
result = await conn.execute(
async with self._session_factory() as session:
result = await session.execute(
sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))
.where(Transcript.conversation_id == conversation_id)
)

View File

@@ -0,0 +1,77 @@
"""Artifact persistence entity for Host-owned artifact store."""
from __future__ import annotations
import sqlalchemy
import datetime
from .base import Base
class AgentArtifact(Base):
    """Metadata row describing one artifact (file, image, tool result, ...).

    Only metadata lives in this table. The blob itself is kept in
    BinaryStorage or in external storage and is located through
    ``storage_key`` / ``storage_type``.
    """

    __tablename__ = 'agent_artifact'

    # Auto-increment ID for sequencing.
    id = sqlalchemy.Column(
        sqlalchemy.Integer,
        primary_key=True,
        autoincrement=True,
    )

    # Unique artifact identifier.
    artifact_id = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=False,
        unique=True,
        index=True,
    )

    # Artifact type: 'image', 'file', 'voice', 'tool_result', 'platform_attachment', etc.
    artifact_type = sqlalchemy.Column(
        sqlalchemy.String(50),
        nullable=False,
    )

    # MIME type of the content.
    mime_type = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=True,
    )

    # Original file name (if applicable).
    name = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=True,
    )

    # Size in bytes.
    size_bytes = sqlalchemy.Column(
        sqlalchemy.BigInteger,
        nullable=True,
    )

    # SHA256 hash of content (for integrity verification).
    sha256 = sqlalchemy.Column(
        sqlalchemy.String(64),
        nullable=True,
    )

    # Source of artifact: 'platform', 'runner', 'tool', 'system'.
    source = sqlalchemy.Column(
        sqlalchemy.String(50),
        nullable=False,
    )

    # --- Storage reference (points to BinaryStorage or external storage) ---

    # Key in BinaryStorage or external storage reference.
    storage_key = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=True,
    )

    # Storage type: 'binary_storage', 'file', 'url', etc.
    storage_type = sqlalchemy.Column(
        sqlalchemy.String(50),
        nullable=False,
        default='binary_storage',
    )

    # --- Context ---

    # Conversation this artifact belongs to.
    conversation_id = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=True,
        index=True,
    )

    # Run ID that created this artifact.
    run_id = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=True,
        index=True,
    )

    # Runner ID that created this artifact.
    runner_id = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=True,
    )

    # Bot UUID that handled this artifact.
    bot_id = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=True,
    )

    # Workspace ID for multi-tenant deployments.
    workspace_id = sqlalchemy.Column(
        sqlalchemy.String(255),
        nullable=True,
    )

    # --- Lifecycle ---

    # When this artifact was created.
    created_at = sqlalchemy.Column(
        sqlalchemy.DateTime,
        nullable=False,
        default=datetime.datetime.utcnow,
    )

    # When this artifact expires (optional).
    expires_at = sqlalchemy.Column(
        sqlalchemy.DateTime,
        nullable=True,
    )

    # Additional metadata as JSON string.
    metadata_json = sqlalchemy.Column(
        sqlalchemy.Text,
        nullable=True,
    )

View File

@@ -17,6 +17,7 @@ from langbot.pkg.entity.persistence.base import Base
# This is required for autogenerate to detect model changes
from langbot.pkg.entity.persistence import (
apikey,
artifact,
bot,
bstorage,
event_log,

View File

@@ -0,0 +1,55 @@
"""add_agent_artifact_table
Revision ID: a1b2c3d4e5f6
Revises: 58846a8d7a81
Create Date: 2026-05-23 20:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
# `revision` is this migration's ID; `down_revision` is the migration it
# revises (both match the "Revision ID" / "Revises" lines in the module docstring).
revision = 'a1b2c3d4e5f6'
down_revision = '58846a8d7a81'
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``agent_artifact`` table and its lookup indexes."""
    op.create_table(
        'agent_artifact',
        sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
        # No column-level unique=True here: uniqueness of artifact_id is
        # enforced by the unique index created below. Declaring both would
        # emit a UNIQUE constraint *and* a unique index on the same column,
        # which is redundant and differs from what the ORM model
        # (unique=True together with index=True) generates.
        sa.Column('artifact_id', sa.String(255), nullable=False),
        sa.Column('artifact_type', sa.String(50), nullable=False),
        sa.Column('mime_type', sa.String(255), nullable=True),
        sa.Column('name', sa.String(255), nullable=True),
        sa.Column('size_bytes', sa.BigInteger(), nullable=True),
        sa.Column('sha256', sa.String(64), nullable=True),
        sa.Column('source', sa.String(50), nullable=False),
        sa.Column('storage_key', sa.String(255), nullable=True),
        sa.Column('storage_type', sa.String(50), nullable=False, server_default='binary_storage'),
        sa.Column('conversation_id', sa.String(255), nullable=True),
        sa.Column('run_id', sa.String(255), nullable=True),
        sa.Column('runner_id', sa.String(255), nullable=True),
        sa.Column('bot_id', sa.String(255), nullable=True),
        sa.Column('workspace_id', sa.String(255), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
        sa.Column('expires_at', sa.DateTime(), nullable=True),
        sa.Column('metadata_json', sa.Text(), nullable=True),
    )

    # Create indexes for agent_artifact
    with op.batch_alter_table('agent_artifact', schema=None) as batch_op:
        batch_op.create_index('ix_agent_artifact_artifact_id', ['artifact_id'], unique=True)
        batch_op.create_index('ix_agent_artifact_conversation_id', ['conversation_id'], unique=False)
        batch_op.create_index('ix_agent_artifact_run_id', ['run_id'], unique=False)
def downgrade() -> None:
    """Remove the ``agent_artifact`` indexes and then the table itself."""
    indexes_to_drop = (
        'ix_agent_artifact_run_id',
        'ix_agent_artifact_conversation_id',
        'ix_agent_artifact_artifact_id',
    )
    with op.batch_alter_table('agent_artifact', schema=None) as batch_op:
        for index_name in indexes_to_drop:
            batch_op.drop_index(index_name)

    op.drop_table('agent_artifact')

View File

@@ -100,6 +100,47 @@ def _build_tool_detail(tool: Any, requested_tool_name: str | None = None) -> dic
}
def _validate_artifact_access(
    session: dict[str, Any],
    artifact_metadata: dict[str, Any],
    operation: str,
) -> tuple[bool, str | None]:
    """Decide whether a run session may access an artifact.

    Access is granted when either of these holds (checked in this order):

    1. The artifact has a ``run_id`` and it equals the session's ``run_id``.
    2. Both the artifact and the session have a ``conversation_id`` and the
       two are equal.

    Anything else is denied; in particular an artifact without a
    ``conversation_id`` is not readable by other runs.

    Args:
        session: Run session dict; ``run_id`` and ``conversation_id`` are read
        artifact_metadata: Artifact metadata dict; ``run_id`` and
            ``conversation_id`` are read
        operation: Operation name interpolated into the denial message
            ('metadata' or 'read')

    Returns:
        ``(True, None)`` when allowed, otherwise ``(False, reason)``.
    """
    owner_run = artifact_metadata.get('run_id')
    created_by_this_run = bool(owner_run) and owner_run == session.get('run_id')
    if created_by_this_run:
        return True, None

    artifact_conversation = artifact_metadata.get('conversation_id')
    session_conversation = session.get('conversation_id')
    same_conversation = (
        bool(artifact_conversation)
        and bool(session_conversation)
        and artifact_conversation == session_conversation
    )
    if same_conversation:
        return True, None

    return False, f'Artifact {operation} access denied: artifact not in session conversation and not created by this run'
def _normalize_uuid_list(values: Any) -> list[str]:
"""Normalize a user/config supplied UUID list while preserving order."""
if not isinstance(values, list):
@@ -1542,6 +1583,169 @@ class RuntimeConnectionHandler(handler.Handler):
self.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'Event page error: {e}')
# ================= Artifact APIs =================
@self.action(PluginToRuntimeAction.ARTIFACT_METADATA)
async def artifact_metadata(data: dict[str, Any]) -> handler.ActionResponse:
    """Return the public metadata of one artifact to a runner.

    The request must name a live run session whose permissions include
    ``'metadata'`` under ``'artifacts'``, and the artifact must pass
    ``_validate_artifact_access`` for that session. Every failure is
    reported as an error response rather than raised.
    """
    run_id = data.get('run_id')
    if not run_id:
        return handler.ActionResponse.error(message='run_id is required')

    artifact_id = data.get('artifact_id')
    if not artifact_id:
        return handler.ActionResponse.error(message='artifact_id is required')

    # The run session must still be registered.
    run_session = await get_session_registry().get(run_id)
    if not run_session:
        return handler.ActionResponse.error(
            message=f'Run session {run_id} not found or expired'
        )

    # The identity check only applies when the caller supplied one and the
    # session recorded one.
    caller_plugin_identity = data.get('caller_plugin_identity')
    if caller_plugin_identity:
        owner_identity = run_session.get('plugin_identity')
        if owner_identity and owner_identity != caller_plugin_identity:
            return handler.ActionResponse.error(
                message=f'Plugin identity mismatch for run_id {run_id}'
            )

    # Permission gate, taken from the permissions stored on the session.
    granted = run_session.get('permissions', {}).get('artifacts', [])
    if 'metadata' not in granted:
        return handler.ActionResponse.error(
            message='Artifact metadata access not authorized'
        )

    from ..agent.runner.artifact_store import ArtifactStore

    artifact_store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())

    try:
        public_meta = await artifact_store.get_metadata(artifact_id)
        if not public_meta:
            return handler.ActionResponse.error(
                message=f'Artifact {artifact_id} not found'
            )

        # Scope check: same run or same conversation.
        allowed, denial_reason = _validate_artifact_access(run_session, public_meta, 'metadata')
        if allowed:
            return handler.ActionResponse.success(data=public_meta)
        return handler.ActionResponse.error(message=denial_reason)
    except Exception as e:
        self.ap.logger.error(f'ARTIFACT_METADATA error: {e}', exc_info=True)
        return handler.ActionResponse.error(message=f'Artifact metadata error: {e}')
@self.action(PluginToRuntimeAction.ARTIFACT_READ)
async def artifact_read(data: dict[str, Any]) -> handler.ActionResponse:
    """Return (a range of) an artifact's content to a runner.

    The request must name a live run session whose permissions include
    ``'read'`` under ``'artifacts'``, and the artifact must pass
    ``_validate_artifact_access`` for that session. ``offset`` and
    ``limit`` select the byte range; non-integer values are coerced with
    ``int()`` and rejected if that fails.
    """
    run_id = data.get('run_id')
    if not run_id:
        return handler.ActionResponse.error(message='run_id is required')

    artifact_id = data.get('artifact_id')
    if not artifact_id:
        return handler.ActionResponse.error(message='artifact_id is required')

    # offset: defaults to 0, must end up as a non-negative int.
    offset = data.get('offset', 0)
    if not isinstance(offset, int):
        try:
            offset = int(offset)
        except (TypeError, ValueError):
            return handler.ActionResponse.error(message='offset must be an integer')
    if offset < 0:
        return handler.ActionResponse.error(message='offset must be >= 0')

    # limit: optional, must end up as a positive int when present.
    limit = data.get('limit')
    if limit is not None and not isinstance(limit, int):
        try:
            limit = int(limit)
        except (TypeError, ValueError):
            return handler.ActionResponse.error(message='limit must be an integer')
    if limit is not None and limit <= 0:
        return handler.ActionResponse.error(message='limit must be > 0')

    # The run session must still be registered.
    run_session = await get_session_registry().get(run_id)
    if not run_session:
        return handler.ActionResponse.error(
            message=f'Run session {run_id} not found or expired'
        )

    # The identity check only applies when the caller supplied one and the
    # session recorded one.
    caller_plugin_identity = data.get('caller_plugin_identity')
    if caller_plugin_identity:
        owner_identity = run_session.get('plugin_identity')
        if owner_identity and owner_identity != caller_plugin_identity:
            return handler.ActionResponse.error(
                message=f'Plugin identity mismatch for run_id {run_id}'
            )

    # Permission gate, taken from the permissions stored on the session.
    granted = run_session.get('permissions', {}).get('artifacts', [])
    if 'read' not in granted:
        return handler.ActionResponse.error(
            message='Artifact read access not authorized'
        )

    from ..agent.runner.artifact_store import ArtifactStore

    artifact_store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())

    try:
        # Metadata is fetched first so the scope check runs before any
        # content is loaded.
        public_meta = await artifact_store.get_metadata(artifact_id)
        if not public_meta:
            return handler.ActionResponse.error(
                message=f'Artifact {artifact_id} not found'
            )

        allowed, denial_reason = _validate_artifact_access(run_session, public_meta, 'read')
        if not allowed:
            return handler.ActionResponse.error(message=denial_reason)

        read_result = await artifact_store.read_artifact(
            artifact_id=artifact_id,
            offset=offset,
            limit=limit,
        )
        if read_result:
            return handler.ActionResponse.success(data=read_result)
        return handler.ActionResponse.error(
            message=f'Failed to read artifact {artifact_id}'
        )
    except ValueError as e:
        # Raised by the store for an invalid offset/limit.
        return handler.ActionResponse.error(message=str(e))
    except Exception as e:
        self.ap.logger.error(f'ARTIFACT_READ error: {e}', exc_info=True)
        return handler.ActionResponse.error(message=f'Artifact read error: {e}')
@self.action(CommonAction.PING)
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
"""Ping"""

View File

@@ -0,0 +1,625 @@
"""Tests for ArtifactStore and artifact action handlers."""
from __future__ import annotations
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import base64
import datetime
import asyncio
from langbot.pkg.agent.runner.artifact_store import ArtifactStore
from langbot.pkg.agent.runner.session_registry import (
AgentRunSessionRegistry,
get_session_registry,
)
class TestArtifactStore:
    """Unit tests for ArtifactStore that run against a mocked session factory."""

    def _make_mock_engine(self):
        """Return a MagicMock constrained to the AsyncEngine interface.

        The store is driven through its ``_session_factory`` attribute, which
        every DB-touching test below patches, so the engine only has to be
        accepted by the ``ArtifactStore`` constructor.
        """
        from sqlalchemy.ext.asyncio import AsyncEngine

        return MagicMock(spec=AsyncEngine)

    @staticmethod
    def _make_write_session():
        """Return a session double whose ``add`` is sync and ``commit`` is async."""
        session = AsyncMock()
        session.add = MagicMock()
        session.commit = AsyncMock()
        return session

    @staticmethod
    def _make_empty_read_session():
        """Return a session double whose ``execute`` result has no first row."""
        empty_result = MagicMock()
        empty_result.scalars.return_value.first.return_value = None
        session = AsyncMock()
        session.execute = AsyncMock(return_value=empty_result)
        return session

    @pytest.mark.asyncio
    async def test_register_artifact_generates_id(self):
        """register_artifact returns a generated ID when none is supplied."""
        store = ArtifactStore(self._make_mock_engine())
        mock_session = self._make_write_session()

        with patch.object(store, '_session_factory') as mock_factory:
            # ``async with factory() as session`` must hand back our double.
            mock_factory.return_value.__aenter__.return_value = mock_session
            artifact_id = await store.register_artifact(
                artifact_id=None,
                artifact_type="image",
                source="platform",
            )

        assert artifact_id is not None
        assert len(artifact_id) == 36  # UUID format

    @pytest.mark.asyncio
    async def test_register_artifact_with_content(self):
        """register_artifact echoes the caller-supplied ID when content is given."""
        store = ArtifactStore(self._make_mock_engine())
        mock_session = self._make_write_session()

        with patch.object(store, '_session_factory') as mock_factory:
            mock_factory.return_value.__aenter__.return_value = mock_session
            content = b"test image content"
            artifact_id = await store.register_artifact(
                artifact_id="art_001",
                artifact_type="image",
                source="platform",
                content=content,
            )

        assert artifact_id == "art_001"

    @pytest.mark.asyncio
    async def test_register_artifact_with_storage_key(self):
        """register_artifact accepts a pre-existing storage_key."""
        store = ArtifactStore(self._make_mock_engine())
        mock_session = self._make_write_session()

        with patch.object(store, '_session_factory') as mock_factory:
            mock_factory.return_value.__aenter__.return_value = mock_session
            artifact_id = await store.register_artifact(
                artifact_id="art_002",
                artifact_type="file",
                source="runner",
                storage_key="existing_key",
                storage_type="binary_storage",
                size_bytes=1024,
            )

        assert artifact_id == "art_002"

    @pytest.mark.asyncio
    async def test_get_metadata_not_found(self):
        """get_metadata returns None when the query yields no row."""
        store = ArtifactStore(self._make_mock_engine())
        mock_session = self._make_empty_read_session()

        with patch.object(store, '_session_factory') as mock_factory:
            mock_factory.return_value.__aenter__.return_value = mock_session
            metadata = await store.get_metadata("nonexistent")

        assert metadata is None

    @pytest.mark.asyncio
    async def test_read_artifact_validates_offset(self):
        """read_artifact rejects a negative offset with ValueError."""
        store = ArtifactStore(self._make_mock_engine())

        with pytest.raises(ValueError, match="offset must be >= 0"):
            await store.read_artifact("art_001", offset=-1)

    @pytest.mark.asyncio
    async def test_read_artifact_validates_limit(self):
        """read_artifact rejects a zero or negative limit with ValueError."""
        store = ArtifactStore(self._make_mock_engine())

        with pytest.raises(ValueError, match="limit must be > 0"):
            await store.read_artifact("art_001", limit=0)

        with pytest.raises(ValueError, match="limit must be > 0"):
            await store.read_artifact("art_001", limit=-5)

    @pytest.mark.asyncio
    async def test_read_artifact_not_found(self):
        """read_artifact returns None when the query yields no row."""
        store = ArtifactStore(self._make_mock_engine())
        mock_session = self._make_empty_read_session()

        with patch.object(store, '_session_factory') as mock_factory:
            mock_factory.return_value.__aenter__.return_value = mock_session
            result = await store.read_artifact("nonexistent")

        assert result is None
class TestArtifactAuthorization:
    """Checks the run_id requirement of the artifact actions.

    NOTE(review): both tests call ``MockHandler.call_action``, which is
    defined inside the ``mock_handler`` fixture below, so they exercise the
    stub's behaviour rather than the real action handlers.
    """

    @pytest.fixture
    def mock_session_registry(self):
        """Reset the module-level registry and return a fresh instance."""
        import langbot.pkg.agent.runner.session_registry as registry_module

        # Clearing the cached global makes get_session_registry() start over.
        registry_module._global_registry = None
        return get_session_registry()

    @pytest.fixture
    def mock_handler(self):
        """Build a Handler subclass whose call_action is a canned stub."""
        from langbot_plugin.runtime.io.handler import Handler

        class MockHandler(Handler):
            def __init__(self):
                # Deliberately does not call Handler.__init__.
                self._responses = {}

            async def call_action(self, action, data, timeout=30):
                # A present, truthy run_id yields the success payload;
                # anything else yields the "run_id is required" error.
                if data.get("run_id"):
                    return {"ok": True, "data": {}}
                return {"ok": False, "message": "run_id is required"}

        return MockHandler()

    @pytest.mark.asyncio
    async def test_artifact_metadata_requires_run_id(self, mock_handler):
        """artifact_metadata without a run_id yields a failure response."""
        payload = {"run_id": None, "artifact_id": "art_001"}
        result = await mock_handler.call_action("artifact_metadata", payload)

        rejected = result.get("ok") is False
        assert rejected or "error" in str(result).lower()

    @pytest.mark.asyncio
    async def test_artifact_read_requires_run_id(self, mock_handler):
        """artifact_read without a run_id yields a failure response."""
        payload = {"run_id": None, "artifact_id": "art_001"}
        result = await mock_handler.call_action("artifact_read", payload)

        rejected = result.get("ok") is False
        assert rejected or "error" in str(result).lower()
class TestArtifactAccessValidation:
    """Authorization rules enforced by ``_validate_artifact_access``.

    Every scenario uses a session for run ``run_001`` that holds both the
    ``metadata`` and ``read`` artifact permissions; only the conversation
    scope of the session and the scope of the artifact vary.
    """

    @staticmethod
    def _session(conversation_id):
        """Build the run_001 session dict bound to *conversation_id*."""
        return {
            "run_id": "run_001",
            "conversation_id": conversation_id,
            "permissions": {"artifacts": ["metadata", "read"]},
        }

    @staticmethod
    def _artifact(artifact_id, conversation_id, run_id):
        """Build the artifact metadata dict passed to the validator."""
        return {
            "artifact_id": artifact_id,
            "conversation_id": conversation_id,
            "run_id": run_id,
        }

    def _call_validate(self, session, metadata, operation="metadata"):
        """Invoke the validator; it is imported lazily at call time."""
        from langbot.pkg.plugin.handler import _validate_artifact_access

        return _validate_artifact_access(session, metadata, operation)

    def test_global_artifact_denied_by_default(self):
        """An artifact with neither conversation nor run scope is denied."""
        verdict, message = self._call_validate(
            self._session("conv_001"),
            self._artifact("art_global", None, None),
        )

        assert verdict is False
        assert "denied" in message.lower()

    def test_own_run_artifact_allowed(self):
        """An artifact created by the same run is allowed across conversations."""
        verdict, message = self._call_validate(
            self._session("conv_001"),
            self._artifact("art_001", "conv_other", "run_001"),
        )

        assert verdict is True
        assert message is None

    def test_same_conversation_allowed(self):
        """An artifact from another run in the same conversation is allowed."""
        verdict, message = self._call_validate(
            self._session("conv_001"),
            self._artifact("art_001", "conv_001", "run_other"),
        )

        assert verdict is True
        assert message is None

    def test_different_conversation_and_run_denied(self):
        """An artifact from another conversation and another run is denied."""
        verdict, message = self._call_validate(
            self._session("conv_001"),
            self._artifact("art_001", "conv_other", "run_other"),
        )

        assert verdict is False
        assert "denied" in message.lower()

    def test_session_without_conversation_denied_for_conversation_artifact(self):
        """A session with no conversation cannot reach another run's artifact."""
        verdict, _message = self._call_validate(
            self._session(None),
            self._artifact("art_001", "conv_001", "run_other"),
        )

        assert verdict is False

    def test_session_without_conversation_allowed_for_own_artifact(self):
        """A session with no conversation can still reach its own run's artifact."""
        verdict, _message = self._call_validate(
            self._session(None),
            self._artifact("art_001", "conv_001", "run_001"),
        )

        assert verdict is True
class TestContextAccessArtifactAPIs:
    """How artifact permissions map to enabled artifact APIs.

    NOTE(review): these tests only check membership in a locally built
    permissions dict; no context-builder code is invoked.
    """

    @pytest.mark.asyncio
    async def test_context_access_has_artifact_apis_when_permitted(self):
        """Both artifact APIs count as enabled when both permissions are listed."""
        permissions = {"artifacts": ["metadata", "read"]}
        granted = permissions.get("artifacts", [])

        artifact_metadata_enabled = "metadata" in granted
        artifact_read_enabled = "read" in granted

        assert artifact_metadata_enabled is True
        assert artifact_read_enabled is True

    @pytest.mark.asyncio
    async def test_context_access_no_artifact_apis_without_permission(self):
        """Neither artifact API counts as enabled when the list is empty."""
        permissions = {"artifacts": []}
        granted = permissions.get("artifacts", [])

        artifact_metadata_enabled = "metadata" in granted
        artifact_read_enabled = "read" in granted

        assert artifact_metadata_enabled is False
        assert artifact_read_enabled is False
class TestArtifactMetadataFieldAlignment:
    """Checks that public metadata matches what SDK ArtifactMetadata expects."""

    def test_row_to_public_dict_excludes_host_only_fields(self):
        """_row_to_public_dict exposes SDK fields and omits Host-only ones."""
        # ArtifactStore and MagicMock come from the module-level imports;
        # only the ORM model needs importing here.
        from langbot.pkg.entity.persistence.artifact import AgentArtifact

        # Attributes that are expected to appear in the public dict.
        public_attrs = {
            "artifact_id": "art_001",
            "artifact_type": "image",
            "mime_type": "image/png",
            "name": "test.png",
            "size_bytes": 1024,
            "sha256": "abc123",
            "source": "platform",
            "conversation_id": "conv_001",
            "run_id": "run_001",
            "runner_id": "plugin:test/plugin/runner",
            "created_at": datetime.datetime(2024, 1, 1, 0, 0, 0),
            "expires_at": None,
            "metadata_json": None,
        }
        # Host-only attributes: set on the row, must not leak into the output.
        host_only_attrs = {
            "bot_id": "bot_001",
            "workspace_id": "ws_001",
            "storage_key": "artifact:art_001",
            "storage_type": "binary_storage",
        }

        mock_row = MagicMock(spec=AgentArtifact)
        for attr, value in {**public_attrs, **host_only_attrs}.items():
            setattr(mock_row, attr, value)

        store = ArtifactStore(MagicMock())
        result = store._row_to_public_dict(mock_row)

        # SDK-compatible fields should be present
        assert result["artifact_id"] == "art_001"
        assert result["artifact_type"] == "image"
        assert result["source"] == "platform"
        assert result["conversation_id"] == "conv_001"
        assert result["run_id"] == "run_001"

        # Host-only fields should NOT be present
        for attr in host_only_attrs:
            assert attr not in result
class TestSessionRegistryPermissions:
    """Round-trips permissions through the session registry."""

    @pytest.fixture
    def session_registry(self):
        """Reset the module-level registry and return a fresh instance."""
        import langbot.pkg.agent.runner.session_registry as registry_module

        registry_module._global_registry = None
        return get_session_registry()

    @staticmethod
    async def _register(registry, run_id, permissions):
        """Register a session that differs only in run_id and permissions."""
        resources = {
            "models": [],
            "tools": [],
            "knowledge_bases": [],
            "files": [],
            "storage": {"plugin_storage": True, "workspace_storage": False},
            "platform_capabilities": {},
        }
        await registry.register(
            run_id=run_id,
            runner_id="plugin:author/plugin/runner",
            query_id=None,
            plugin_identity="author/plugin",
            resources=resources,
            permissions=permissions,
            conversation_id="conv_001",
        )

    @pytest.mark.asyncio
    async def test_register_stores_permissions(self, session_registry):
        """register() keeps each permission list retrievable via get()."""
        granted = {
            "artifacts": ["metadata", "read"],
            "history": ["page"],
            "events": ["get"],
        }
        await self._register(session_registry, "run_001", granted)

        session = await session_registry.get("run_001")

        assert session is not None
        stored = session["permissions"]
        assert stored["artifacts"] == ["metadata", "read"]
        assert stored["history"] == ["page"]
        assert stored["events"] == ["get"]

    @pytest.mark.asyncio
    async def test_register_with_empty_permissions(self, session_registry):
        """register() keeps an empty permissions dict as-is."""
        await self._register(session_registry, "run_002", {})

        session = await session_registry.get("run_002")

        assert session is not None
        assert session["permissions"] == {}
class TestArtifactStoreRealSQLite:
    """Integration tests for ArtifactStore against an in-memory SQLite DB."""

    @pytest.fixture
    async def db_engine(self):
        """Yield an async in-memory SQLite engine with all tables created.

        NOTE(review): this is an async generator decorated with plain
        ``pytest.fixture``; confirm the project's pytest-asyncio mode
        supports that.
        """
        from sqlalchemy.ext.asyncio import create_async_engine
        from langbot.pkg.entity.persistence.base import Base

        # Not referenced below; presumably imported so the models register
        # their tables on Base.metadata before create_all runs — verify.
        from langbot.pkg.entity.persistence.artifact import AgentArtifact  # noqa: F401
        from langbot.pkg.entity.persistence.bstorage import BinaryStorage  # noqa: F401

        engine = create_async_engine("sqlite+aiosqlite:///:memory:")

        # Create tables manually for the in-memory DB.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        yield engine

        await engine.dispose()

    @pytest.mark.asyncio
    async def test_register_get_metadata_round_trip(self, db_engine):
        """register_artifact followed by get_metadata returns the public fields."""
        store = ArtifactStore(db_engine)

        content = b"test image content for round trip"
        artifact_id = await store.register_artifact(
            artifact_id="art_real_001",
            artifact_type="image",
            source="platform",
            mime_type="image/png",
            name="test.png",
            content=content,
            conversation_id="conv_001",
            run_id="run_001",
        )
        assert artifact_id == "art_real_001"

        metadata = await store.get_metadata(artifact_id)
        assert metadata is not None
        assert metadata["artifact_id"] == "art_real_001"
        assert metadata["artifact_type"] == "image"
        assert metadata["mime_type"] == "image/png"
        assert metadata["source"] == "platform"
        assert metadata["conversation_id"] == "conv_001"
        assert metadata["run_id"] == "run_001"

        # Host-only fields must not leak into public metadata.
        assert "storage_key" not in metadata
        assert "storage_type" not in metadata
        assert "bot_id" not in metadata
        assert "workspace_id" not in metadata

    @pytest.mark.asyncio
    async def test_read_artifact_round_trip(self, db_engine):
        """register_artifact followed by read_artifact returns the full content."""
        store = ArtifactStore(db_engine)

        content = b"test file content for read test"
        artifact_id = await store.register_artifact(
            artifact_id="art_real_002",
            artifact_type="file",
            source="runner",
            mime_type="text/plain",
            name="test.txt",
            content=content,
            conversation_id="conv_001",
            run_id="run_001",
        )

        result = await store.read_artifact(artifact_id)
        assert result is not None
        assert result["artifact_id"] == "art_real_002"
        assert result["mime_type"] == "text/plain"
        assert result["offset"] == 0
        assert result["length"] == len(content)
        assert result["has_more"] is False

        # The payload is base64-encoded; decode before comparing.
        decoded_content = base64.b64decode(result["content_base64"])
        assert decoded_content == content

    @pytest.mark.asyncio
    async def test_read_artifact_with_offset_limit(self, db_engine):
        """read_artifact honours offset and limit for a range read."""
        store = ArtifactStore(db_engine)

        content = b"0123456789" * 100  # 1000 bytes
        artifact_id = await store.register_artifact(
            artifact_id="art_real_003",
            artifact_type="file",
            source="runner",
            mime_type="application/octet-stream",
            content=content,
        )

        result = await store.read_artifact(artifact_id, offset=100, limit=100)
        assert result is not None
        assert result["offset"] == 100
        assert result["length"] == 100

        decoded_content = base64.b64decode(result["content_base64"])
        assert decoded_content == content[100:200]

    @pytest.mark.asyncio
    async def test_read_artifact_has_more(self, db_engine):
        """read_artifact sets has_more when the limit is below the content size."""
        store = ArtifactStore(db_engine)

        content = b"0123456789" * 100  # 1000 bytes
        artifact_id = await store.register_artifact(
            artifact_id="art_real_004",
            artifact_type="file",
            source="runner",
            content=content,
        )

        result = await store.read_artifact(artifact_id, offset=0, limit=100)
        assert result is not None
        assert result["has_more"] is True
        assert result["length"] == 100

    @pytest.mark.asyncio
    async def test_metadata_sdk_validation(self, db_engine):
        """Stored metadata validates against the SDK ArtifactMetadata model."""
        from langbot_plugin.api.entities.builtin.agent_runner.artifact import ArtifactMetadata

        store = ArtifactStore(db_engine)

        artifact_id = await store.register_artifact(
            artifact_id="art_real_005",
            artifact_type="file",
            source="runner",
            mime_type="application/pdf",
            name="document.pdf",
            size_bytes=1024,
            conversation_id="conv_001",
            run_id="run_001",
            runner_id="plugin:test/plugin/runner",
        )

        metadata = await store.get_metadata(artifact_id)
        assert metadata is not None

        # Should not raise ValidationError
        validated = ArtifactMetadata.model_validate(metadata)
        assert validated.artifact_id == "art_real_005"
        assert validated.artifact_type == "file"

View File

@@ -73,49 +73,78 @@ class TestEventLogStore:
@pytest.mark.asyncio
async def test_append_event(self, mock_db_engine):
    """append_event returns the supplied event_id when the session is mocked."""
    from unittest.mock import AsyncMock, MagicMock, patch

    store = EventLogStore(mock_db_engine)

    # Session double: ``add`` is synchronous, ``commit`` is awaited.
    mock_session = AsyncMock()
    mock_session.add = MagicMock()
    mock_session.commit = AsyncMock()

    # The store is only called once the session factory is patched, so the
    # call never reaches the unconfigured mock engine.
    with patch.object(store, '_session_factory') as mock_factory:
        mock_factory.return_value.__aenter__.return_value = mock_session
        event_id = await store.append_event(
            event_id="evt_1",
            event_type="message.received",
            source="platform",
            bot_id="bot_1",
            conversation_id="conv_1",
            actor_type="user",
            actor_id="user_1",
            input_summary="Hello world",
            run_id="run_1",
            runner_id="plugin:test/plugin/runner",
        )

    assert event_id == "evt_1"
@pytest.mark.asyncio
async def test_append_event_truncates_input_summary(self, mock_db_engine):
    """append_event accepts a 2000-character input summary.

    NOTE(review): only the returned event_id is asserted; the stored
    summary length is not inspected here.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    store = EventLogStore(mock_db_engine)

    mock_session = AsyncMock()
    mock_session.add = MagicMock()
    mock_session.commit = AsyncMock()

    # Call the store only while the session factory is patched.
    with patch.object(store, '_session_factory') as mock_factory:
        mock_factory.return_value.__aenter__.return_value = mock_session
        long_text = "x" * 2000
        event_id = await store.append_event(
            event_id="evt_2",
            event_type="message.received",
            source="platform",
            input_summary=long_text,
        )

    assert event_id == "evt_2"
@pytest.mark.asyncio
async def test_page_events_with_conversation_filter(self, mock_db_engine):
    """page_events with a conversation_id filter returns a list of items."""
    from unittest.mock import AsyncMock, MagicMock, patch

    store = EventLogStore(mock_db_engine)

    # Query result double: ``scalars().all()`` yields no rows.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = []

    mock_session = AsyncMock()
    mock_session.execute = AsyncMock(return_value=mock_result)

    # Call the store only while the session factory is patched.
    with patch.object(store, '_session_factory') as mock_factory:
        mock_factory.return_value.__aenter__.return_value = mock_session
        items, next_seq, has_more = await store.page_events(
            conversation_id="conv_1",
            limit=10,
        )

    assert isinstance(items, list)
class TestTranscriptStore:
@@ -124,75 +153,129 @@ class TestTranscriptStore:
@pytest.mark.asyncio
async def test_append_transcript(self, mock_db_engine):
    """append_transcript returns a non-None ID when transcript_id is None."""
    from unittest.mock import AsyncMock, MagicMock, patch

    store = TranscriptStore(mock_db_engine)

    mock_session = AsyncMock()
    mock_session.add = MagicMock()
    mock_session.commit = AsyncMock()

    # Mock _get_next_seq so sequence allocation needs no DB, and patch the
    # session factory before the store is called.
    with patch.object(store, '_get_next_seq', return_value=1):
        with patch.object(store, '_session_factory') as mock_factory:
            mock_factory.return_value.__aenter__.return_value = mock_session
            transcript_id = await store.append_transcript(
                transcript_id=None,  # Auto-generate
                event_id="evt_1",
                conversation_id="conv_1",
                role="user",
                content="Hello",
            )

    assert transcript_id is not None
@pytest.mark.asyncio
async def test_append_transcript_with_artifacts(self, mock_db_engine):
    """append_transcript accepts artifact_refs and returns a non-None ID."""
    from unittest.mock import AsyncMock, MagicMock, patch

    store = TranscriptStore(mock_db_engine)

    mock_session = AsyncMock()
    mock_session.add = MagicMock()
    mock_session.commit = AsyncMock()

    # Patch sequence allocation and the session factory before calling the
    # store, so nothing touches the mock engine.
    with patch.object(store, '_get_next_seq', return_value=1):
        with patch.object(store, '_session_factory') as mock_factory:
            mock_factory.return_value.__aenter__.return_value = mock_session
            transcript_id = await store.append_transcript(
                transcript_id=None,  # Auto-generate
                event_id="evt_2",
                conversation_id="conv_1",
                role="assistant",
                content="Here's an image",
                artifact_refs=[
                    {"artifact_id": "art_1", "artifact_type": "image", "url": "http://example.com/img.png"}
                ],
            )

    assert transcript_id is not None
@pytest.mark.asyncio
async def test_page_transcript_backward(self, mock_db_engine):
    """page_transcript with direction="backward" returns a list of items."""
    from unittest.mock import AsyncMock, MagicMock, patch

    store = TranscriptStore(mock_db_engine)

    # Query result double: ``scalars().all()`` yields no rows.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = []

    mock_session = AsyncMock()
    mock_session.execute = AsyncMock(return_value=mock_result)

    # Call the store only while the session factory is patched.
    with patch.object(store, '_session_factory') as mock_factory:
        mock_factory.return_value.__aenter__.return_value = mock_session
        items, next_seq, prev_seq, has_more = await store.page_transcript(
            conversation_id="conv_1",
            limit=10,
            direction="backward",
        )

    assert isinstance(items, list)
@pytest.mark.asyncio
async def test_page_transcript_has_hard_limit(self, mock_db_engine):
    """page_transcript never returns more than HARD_LIMIT items.

    NOTE(review): the mocked query result is empty, so the length check
    presumably holds trivially — confirm whether a stronger assertion on
    the issued query is wanted.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    store = TranscriptStore(mock_db_engine)

    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = []

    mock_session = AsyncMock()
    mock_session.execute = AsyncMock(return_value=mock_result)

    # Call the store only while the session factory is patched.
    with patch.object(store, '_session_factory') as mock_factory:
        mock_factory.return_value.__aenter__.return_value = mock_session
        # Request more than the hard limit
        items, next_seq, prev_seq, has_more = await store.page_transcript(
            conversation_id="conv_1",
            limit=200,  # Request 200, but hard limit is 100
        )

    # The store should cap at 100
    assert len(items) <= store.HARD_LIMIT
@pytest.mark.asyncio
async def test_search_transcript(self, mock_db_engine):
    """search_transcript returns a list when the query yields no rows."""
    from unittest.mock import AsyncMock, MagicMock, patch

    store = TranscriptStore(mock_db_engine)

    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = []

    mock_session = AsyncMock()
    mock_session.execute = AsyncMock(return_value=mock_result)

    # Call the store only while the session factory is patched.
    with patch.object(store, '_session_factory') as mock_factory:
        mock_factory.return_value.__aenter__.return_value = mock_session
        items = await store.search_transcript(
            conversation_id="conv_1",
            query_text="database",
            top_k=10,
        )

    assert isinstance(items, list)
class TestHistoryPageAuthorization:
@@ -259,50 +342,244 @@ class TestContextAccessPopulation:
@pytest.mark.asyncio
async def test_context_access_has_history_apis_when_permitted(self, mock_db_engine):
    """get_latest_cursor returns None or a string cursor.

    This stands in for a context-builder test: for now it only verifies
    that the store method works against a mocked session.
    """
    from unittest.mock import AsyncMock, MagicMock, patch

    store = TranscriptStore(mock_db_engine)

    # Query result double: ``scalars().first()`` finds no row.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None

    mock_session = AsyncMock()
    mock_session.execute = AsyncMock(return_value=mock_result)

    # Call the store only while the session factory is patched.
    with patch.object(store, '_session_factory') as mock_factory:
        mock_factory.return_value.__aenter__.return_value = mock_session
        cursor = await store.get_latest_cursor("conv_1")

    # Should return None or a cursor string
    assert cursor is None or isinstance(cursor, str)
@pytest.mark.asyncio
async def test_context_access_shows_has_history_before(self, mock_db_engine):
    """has_history_before returns a bool when the scalar result is 0."""
    from unittest.mock import AsyncMock, MagicMock, patch

    store = TranscriptStore(mock_db_engine)

    # Query result double: ``scalar()`` reports zero.
    mock_result = MagicMock()
    mock_result.scalar.return_value = 0

    mock_session = AsyncMock()
    mock_session.execute = AsyncMock(return_value=mock_result)

    # Call the store only while the session factory is patched.
    with patch.object(store, '_session_factory') as mock_factory:
        mock_factory.return_value.__aenter__.return_value = mock_session
        has_history = await store.has_history_before("conv_1", 10)

    assert isinstance(has_history, bool)
class TestEventLogStoreRealSQLite:
    """Integration tests for EventLogStore against an in-memory SQLite DB."""

    @pytest.fixture
    async def db_engine(self):
        """Yield an async in-memory SQLite engine with all tables created.

        NOTE(review): this is an async generator decorated with plain
        ``pytest.fixture``; confirm the project's pytest-asyncio mode
        supports that.
        """
        from sqlalchemy.ext.asyncio import create_async_engine
        from langbot.pkg.entity.persistence.base import Base

        # Not referenced below; presumably imported so the model registers
        # its table on Base.metadata before create_all runs — verify.
        from langbot.pkg.entity.persistence.event_log import EventLog  # noqa: F401

        engine = create_async_engine("sqlite+aiosqlite:///:memory:")

        # Create tables
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        yield engine

        await engine.dispose()

    @pytest.mark.asyncio
    async def test_append_get_event_round_trip(self, db_engine):
        """append_event followed by get_event returns the stored fields."""
        store = EventLogStore(db_engine)

        event_id = await store.append_event(
            event_id="evt_real_001",
            event_type="message.received",
            source="platform",
            bot_id="bot_001",
            conversation_id="conv_001",
            actor_type="user",
            actor_id="user_001",
            actor_name="Test User",
            input_summary="Hello world",
            run_id="run_001",
            runner_id="plugin:test/plugin/runner",
        )
        assert event_id == "evt_real_001"

        event = await store.get_event(event_id)
        assert event is not None
        assert event["event_id"] == "evt_real_001"
        assert event["event_type"] == "message.received"
        assert event["source"] == "platform"
        assert event["conversation_id"] == "conv_001"
        assert event["actor_type"] == "user"
        assert event["actor_id"] == "user_001"

    @pytest.mark.asyncio
    async def test_page_events(self, db_engine):
        """page_events returns ``limit`` items and has_more when more exist."""
        store = EventLogStore(db_engine)

        # Five events in one conversation, then request a page of three.
        for i in range(5):
            await store.append_event(
                event_id=f"evt_real_{i:03d}",
                event_type="message.received",
                source="platform",
                conversation_id="conv_001",
                input_summary=f"Message {i}",
            )

        items, next_seq, has_more = await store.page_events(
            conversation_id="conv_001",
            limit=3,
        )

        assert len(items) == 3
        assert has_more is True

    @pytest.mark.asyncio
    async def test_get_latest_cursor(self, db_engine):
        """get_latest_cursor returns a positive integer string after appends."""
        store = EventLogStore(db_engine)

        for i in range(3):
            await store.append_event(
                event_id=f"evt_cursor_{i:03d}",
                event_type="message.received",
                source="platform",
                conversation_id="conv_cursor",
            )

        cursor = await store.get_latest_cursor("conv_cursor")
        assert cursor is not None
        assert int(cursor) > 0
class TestTranscriptStoreRealSQLite:
    """Integration tests for TranscriptStore against an in-memory SQLite DB."""

    @pytest.fixture
    async def db_engine(self):
        """Yield an async in-memory SQLite engine with all tables created.

        NOTE(review): this is an async generator decorated with plain
        ``pytest.fixture``; confirm the project's pytest-asyncio mode
        supports that.
        """
        from sqlalchemy.ext.asyncio import create_async_engine
        from langbot.pkg.entity.persistence.base import Base

        # Not referenced below; presumably imported so the model registers
        # its table on Base.metadata before create_all runs — verify.
        from langbot.pkg.entity.persistence.transcript import Transcript  # noqa: F401

        engine = create_async_engine("sqlite+aiosqlite:///:memory:")

        # Create tables
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        yield engine

        await engine.dispose()

    @pytest.mark.asyncio
    async def test_append_page_transcript_round_trip(self, db_engine):
        """append_transcript followed by page_transcript returns all items."""
        store = TranscriptStore(db_engine)

        # Three items with alternating roles in one conversation.
        for i in range(3):
            await store.append_transcript(
                transcript_id=f"trans_real_{i:03d}",
                event_id=f"evt_{i:03d}",
                conversation_id="conv_001",
                role="user" if i % 2 == 0 else "assistant",
                content=f"Message {i}",
            )

        items, next_seq, prev_seq, has_more = await store.page_transcript(
            conversation_id="conv_001",
            limit=10,
        )

        assert len(items) == 3
        assert items[0]["conversation_id"] == "conv_001"

    @pytest.mark.asyncio
    async def test_search_transcript_real_db(self, db_engine):
        """search_transcript finds at least one item containing the query text."""
        store = TranscriptStore(db_engine)

        await store.append_transcript(
            transcript_id="trans_search_001",
            event_id="evt_search_001",
            conversation_id="conv_search",
            role="user",
            content="I want to learn about databases",
        )
        await store.append_transcript(
            transcript_id="trans_search_002",
            event_id="evt_search_002",
            conversation_id="conv_search",
            role="assistant",
            content="Here is information about databases",
        )

        # Search for "database"
        items = await store.search_transcript(
            conversation_id="conv_search",
            query_text="database",
        )

        # Should find at least one match
        assert len(items) >= 1

    @pytest.mark.asyncio
    async def test_get_latest_cursor_real_db(self, db_engine):
        """get_latest_cursor returns a positive integer string after appends."""
        store = TranscriptStore(db_engine)

        for i in range(3):
            await store.append_transcript(
                transcript_id=f"trans_cursor_{i:03d}",
                event_id=f"evt_cursor_{i:03d}",
                conversation_id="conv_cursor",
                role="user",
                content=f"Message {i}",
            )

        cursor = await store.get_latest_cursor("conv_cursor")
        assert cursor is not None
        assert int(cursor) > 0
# Fixtures
@pytest.fixture
def mock_db_engine():
    """Create a mock database engine for AsyncSession-based stores.

    The engine is a MagicMock constrained to the AsyncEngine interface. Its
    ``connect()`` is stubbed to return an async context manager yielding a
    connection double whose ``execute`` and ``commit`` are awaitable.
    """
    from unittest.mock import MagicMock, AsyncMock
    from sqlalchemy.ext.asyncio import AsyncEngine

    engine = MagicMock(spec=AsyncEngine)

    # Connection double: every query returns the same empty result.
    mock_conn = MagicMock()
    mock_result = MagicMock()
    mock_result.fetchone.return_value = None
    mock_result.fetchall.return_value = []
    mock_result.scalar.return_value = 0
    mock_conn.execute = AsyncMock(return_value=mock_result)
    mock_conn.commit = AsyncMock()

    # Async context manager handed back by connect().
    class AsyncConnectContextManager:
        async def __aenter__(self):
            return mock_conn

        async def __aexit__(self, *args):
            pass

    # connect() should return an async context manager
    engine.connect = MagicMock(return_value=AsyncConnectContextManager())

    return engine