mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat(agent-runner): add artifact store pull APIs
This commit is contained in:
300
src/langbot/pkg/agent/runner/artifact_store.py
Normal file
300
src/langbot/pkg/agent/runner/artifact_store.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Artifact store for managing Host-owned artifacts."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import datetime
|
||||
import typing
|
||||
import uuid
|
||||
import base64
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from ...entity.persistence.artifact import AgentArtifact
|
||||
from ...entity.persistence.bstorage import BinaryStorage
|
||||
|
||||
|
||||
class ArtifactStore:
|
||||
"""Store for AgentArtifact records.
|
||||
|
||||
Handles artifact metadata registration and content retrieval.
|
||||
Actual blob storage is delegated to BinaryStorage or external storage.
|
||||
|
||||
All methods are async and use the provided database engine.
|
||||
"""
|
||||
|
||||
engine: AsyncEngine
|
||||
|
||||
# Hard limits
|
||||
MAX_INLINE_READ_BYTES = 1024 * 1024 # 1MB max for inline base64
|
||||
MAX_RANGE_READ_BYTES = 10 * 1024 * 1024 # 10MB max for range reads
|
||||
|
||||
def __init__(self, engine: AsyncEngine):
|
||||
self.engine = engine
|
||||
self._session_factory = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def register_artifact(
|
||||
self,
|
||||
artifact_id: str | None,
|
||||
artifact_type: str,
|
||||
source: str,
|
||||
storage_key: str | None = None,
|
||||
storage_type: str = 'binary_storage',
|
||||
mime_type: str | None = None,
|
||||
name: str | None = None,
|
||||
size_bytes: int | None = None,
|
||||
sha256: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
run_id: str | None = None,
|
||||
runner_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
expires_at: datetime.datetime | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
content: bytes | None = None,
|
||||
) -> str:
|
||||
"""Register a new artifact.
|
||||
|
||||
If content is provided and storage_key is None, stores content
|
||||
in BinaryStorage automatically.
|
||||
|
||||
Args:
|
||||
artifact_id: Unique artifact ID (generated if None)
|
||||
artifact_type: Type of artifact (image, file, voice, tool_result, etc.)
|
||||
source: Source of artifact (platform, runner, tool, system)
|
||||
storage_key: Key in BinaryStorage or external reference
|
||||
storage_type: Storage type (binary_storage, file, url)
|
||||
mime_type: MIME type
|
||||
name: Original file name
|
||||
size_bytes: Size in bytes
|
||||
sha256: SHA256 hash
|
||||
conversation_id: Conversation ID
|
||||
run_id: Run ID that created this
|
||||
runner_id: Runner ID that created this
|
||||
bot_id: Bot UUID
|
||||
workspace_id: Workspace ID
|
||||
expires_at: Expiration time
|
||||
metadata: Additional metadata
|
||||
content: Optional content to store in BinaryStorage
|
||||
|
||||
Returns:
|
||||
The artifact_id
|
||||
"""
|
||||
if artifact_id is None:
|
||||
artifact_id = str(uuid.uuid4())
|
||||
|
||||
# If content provided, store in BinaryStorage
|
||||
if content is not None and storage_key is None:
|
||||
storage_key = f"artifact:{artifact_id}"
|
||||
storage_type = 'binary_storage'
|
||||
if size_bytes is None:
|
||||
size_bytes = len(content)
|
||||
|
||||
async with self._session_factory() as session:
|
||||
# Store content in BinaryStorage if provided
|
||||
if content is not None:
|
||||
binary_storage = BinaryStorage(
|
||||
unique_key=f'artifact:{artifact_id}',
|
||||
key=storage_key,
|
||||
owner_type='artifact',
|
||||
owner='host',
|
||||
value=content,
|
||||
)
|
||||
session.add(binary_storage)
|
||||
|
||||
# Store artifact metadata
|
||||
artifact = AgentArtifact(
|
||||
artifact_id=artifact_id,
|
||||
artifact_type=artifact_type,
|
||||
mime_type=mime_type,
|
||||
name=name,
|
||||
size_bytes=size_bytes,
|
||||
sha256=sha256,
|
||||
source=source,
|
||||
storage_key=storage_key,
|
||||
storage_type=storage_type,
|
||||
conversation_id=conversation_id,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
expires_at=expires_at,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
session.add(artifact)
|
||||
await session.commit()
|
||||
|
||||
return artifact_id
|
||||
|
||||
async def get_metadata(
|
||||
self,
|
||||
artifact_id: str,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Get artifact metadata (public fields only, no internal storage info).
|
||||
|
||||
Args:
|
||||
artifact_id: Artifact ID
|
||||
|
||||
Returns:
|
||||
Artifact metadata dict compatible with SDK ArtifactMetadata, or None if not found
|
||||
"""
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(AgentArtifact).where(
|
||||
AgentArtifact.artifact_id == artifact_id
|
||||
)
|
||||
)
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_public_dict(row)
|
||||
|
||||
async def _get_internal_record(
|
||||
self,
|
||||
artifact_id: str,
|
||||
) -> AgentArtifact | None:
|
||||
"""Get full artifact record including internal fields.
|
||||
|
||||
Used internally by read_artifact to access storage_key/storage_type.
|
||||
|
||||
Args:
|
||||
artifact_id: Artifact ID
|
||||
|
||||
Returns:
|
||||
AgentArtifact ORM instance, or None if not found
|
||||
"""
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(AgentArtifact).where(
|
||||
AgentArtifact.artifact_id == artifact_id
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
artifact_id: str,
|
||||
offset: int = 0,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Read artifact content.
|
||||
|
||||
For small artifacts, returns content_base64 directly.
|
||||
For large artifacts, returns file_key for chunked transfer.
|
||||
|
||||
Args:
|
||||
artifact_id: Artifact ID
|
||||
offset: Byte offset to start reading from (must be >= 0)
|
||||
limit: Maximum bytes to read (must be > 0 if provided)
|
||||
|
||||
Returns:
|
||||
ArtifactReadResult dict, or None if not found
|
||||
|
||||
Raises:
|
||||
ValueError: If offset < 0 or limit <= 0
|
||||
"""
|
||||
# Validate offset and limit
|
||||
if offset < 0:
|
||||
raise ValueError("offset must be >= 0")
|
||||
|
||||
if limit is not None and limit <= 0:
|
||||
raise ValueError("limit must be > 0")
|
||||
|
||||
# Get internal record (includes storage_key/storage_type)
|
||||
record = await self._get_internal_record(artifact_id)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
storage_type = record.storage_type or 'binary_storage'
|
||||
storage_key = record.storage_key
|
||||
size_bytes = record.size_bytes or 0
|
||||
|
||||
# Cap limit at hard limit
|
||||
if limit is None:
|
||||
limit = self.MAX_INLINE_READ_BYTES
|
||||
limit = min(limit, self.MAX_RANGE_READ_BYTES)
|
||||
|
||||
# For binary_storage, read content
|
||||
if storage_type == 'binary_storage' and storage_key:
|
||||
content = await self._read_binary_storage(storage_key)
|
||||
if content is None:
|
||||
return None
|
||||
|
||||
# Apply offset and limit
|
||||
if offset > 0:
|
||||
content = content[offset:]
|
||||
if limit and len(content) > limit:
|
||||
content = content[:limit]
|
||||
has_more = True
|
||||
else:
|
||||
has_more = False
|
||||
|
||||
return {
|
||||
'artifact_id': artifact_id,
|
||||
'mime_type': record.mime_type,
|
||||
'size_bytes': size_bytes,
|
||||
'offset': offset,
|
||||
'length': len(content),
|
||||
'content_base64': base64.b64encode(content).decode('utf-8'),
|
||||
'file_key': None,
|
||||
'has_more': has_more,
|
||||
}
|
||||
|
||||
# For other storage types, return storage reference
|
||||
# (caller can use file_key for chunked transfer)
|
||||
return {
|
||||
'artifact_id': artifact_id,
|
||||
'mime_type': record.mime_type,
|
||||
'size_bytes': size_bytes,
|
||||
'offset': offset,
|
||||
'length': None,
|
||||
'content_base64': None,
|
||||
'file_key': storage_key,
|
||||
'has_more': False,
|
||||
}
|
||||
|
||||
async def _read_binary_storage(self, key: str) -> bytes | None:
|
||||
"""Read content from BinaryStorage.
|
||||
|
||||
Uses unique_key for isolation to prevent cross-artifact access.
|
||||
|
||||
Args:
|
||||
key: The unique_key used when storing the artifact
|
||||
|
||||
Returns:
|
||||
Content bytes, or None if not found
|
||||
"""
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(BinaryStorage).where(BinaryStorage.unique_key == key)
|
||||
)
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return row.value
|
||||
|
||||
def _row_to_public_dict(self, row: AgentArtifact) -> dict[str, typing.Any]:
|
||||
"""Convert an AgentArtifact row to public dict.
|
||||
|
||||
Returns only fields that match SDK ArtifactMetadata entity.
|
||||
Host-only fields (bot_id, workspace_id, storage_key, storage_type) are excluded.
|
||||
"""
|
||||
return {
|
||||
'artifact_id': row.artifact_id,
|
||||
'artifact_type': row.artifact_type,
|
||||
'mime_type': row.mime_type,
|
||||
'name': row.name,
|
||||
'size_bytes': row.size_bytes,
|
||||
'sha256': row.sha256,
|
||||
'source': row.source,
|
||||
'conversation_id': row.conversation_id,
|
||||
'run_id': row.run_id,
|
||||
'runner_id': row.runner_id,
|
||||
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
|
||||
'expires_at': int(row.expires_at.timestamp()) if row.expires_at else None,
|
||||
'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
|
||||
}
|
||||
@@ -891,11 +891,14 @@ class AgentRunContextBuilder:
|
||||
permissions = descriptor.permissions or {}
|
||||
history_permissions = permissions.get('history', [])
|
||||
event_permissions = permissions.get('events', [])
|
||||
artifact_permissions = permissions.get('artifacts', [])
|
||||
|
||||
history_page_enabled = 'page' in history_permissions and conversation_id is not None
|
||||
history_search_enabled = 'search' in history_permissions and conversation_id is not None
|
||||
event_get_enabled = 'get' in event_permissions
|
||||
event_page_enabled = 'page' in event_permissions and conversation_id is not None
|
||||
artifact_metadata_enabled = 'metadata' in artifact_permissions
|
||||
artifact_read_enabled = 'read' in artifact_permissions
|
||||
|
||||
# Get latest cursor and has_history_before if conversation exists
|
||||
latest_cursor = None
|
||||
@@ -931,8 +934,8 @@ class AgentRunContextBuilder:
|
||||
'history_search': history_search_enabled,
|
||||
'event_get': event_get_enabled,
|
||||
'event_page': event_page_enabled,
|
||||
'artifact_metadata': False, # TODO: Implement artifact store
|
||||
'artifact_read': False,
|
||||
'artifact_metadata': artifact_metadata_enabled,
|
||||
'artifact_read': artifact_read_enabled,
|
||||
'state': True,
|
||||
'storage': True,
|
||||
},
|
||||
|
||||
@@ -7,7 +7,8 @@ import typing
|
||||
import uuid
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from ...entity.persistence.event_log import EventLog
|
||||
from ...entity.persistence.transcript import Transcript
|
||||
@@ -27,6 +28,9 @@ class EventLogStore:
|
||||
|
||||
def __init__(self, engine: AsyncEngine):
|
||||
self.engine = engine
|
||||
self._session_factory = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def append_event(
|
||||
self,
|
||||
@@ -83,9 +87,8 @@ class EventLogStore:
|
||||
if input_summary and len(input_summary) > self.MAX_INPUT_SUMMARY_LENGTH:
|
||||
input_summary = input_summary[:self.MAX_INPUT_SUMMARY_LENGTH - 3] + "..."
|
||||
|
||||
async with self.engine.connect() as conn:
|
||||
await conn.execute(
|
||||
sqlalchemy.insert(EventLog).values(
|
||||
async with self._session_factory() as session:
|
||||
event = EventLog(
|
||||
event_id=event_id,
|
||||
event_type=event_type,
|
||||
event_time=event_time,
|
||||
@@ -107,8 +110,8 @@ class EventLogStore:
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
await conn.commit()
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
|
||||
return event_id
|
||||
|
||||
@@ -124,14 +127,14 @@ class EventLogStore:
|
||||
Returns:
|
||||
Event record as dict, or None if not found
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(EventLog).where(EventLog.event_id == event_id)
|
||||
)
|
||||
row = result.fetchone()
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_dict(row[0])
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def page_events(
|
||||
self,
|
||||
@@ -153,7 +156,7 @@ class EventLogStore:
|
||||
"""
|
||||
limit = min(limit, 100) # Hard cap
|
||||
|
||||
async with self.engine.connect() as conn:
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(EventLog)
|
||||
|
||||
if conversation_id is not None:
|
||||
@@ -167,10 +170,10 @@ class EventLogStore:
|
||||
|
||||
query = query.order_by(EventLog.id.desc()).limit(limit + 1)
|
||||
|
||||
result = await conn.execute(query)
|
||||
rows = result.fetchall()
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
items = [self._row_to_dict(row[0]) for row in rows[:limit]]
|
||||
items = [self._row_to_dict(row) for row in rows[:limit]]
|
||||
has_more = len(rows) > limit
|
||||
next_seq = items[-1]['id'] if items and has_more else None
|
||||
|
||||
@@ -188,17 +191,17 @@ class EventLogStore:
|
||||
Returns:
|
||||
Cursor string (seq number), or None if no events
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(EventLog.id)
|
||||
.where(EventLog.conversation_id == conversation_id)
|
||||
.order_by(EventLog.id.desc())
|
||||
.limit(1)
|
||||
)
|
||||
row = result.fetchone()
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return str(row[0])
|
||||
return str(row)
|
||||
|
||||
async def has_events_before(
|
||||
self,
|
||||
@@ -214,8 +217,8 @@ class EventLogStore:
|
||||
Returns:
|
||||
True if there are events before
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(sqlalchemy.func.count())
|
||||
.select_from(EventLog)
|
||||
.where(
|
||||
|
||||
@@ -125,6 +125,7 @@ class AgentRunOrchestrator:
|
||||
query_id=None, # No query_id in event-first mode
|
||||
plugin_identity=descriptor.get_plugin_id(),
|
||||
resources=resources,
|
||||
permissions=descriptor.permissions or {},
|
||||
conversation_id=event.conversation_id,
|
||||
)
|
||||
|
||||
@@ -222,6 +223,7 @@ class AgentRunOrchestrator:
|
||||
query_id=query.query_id,
|
||||
plugin_identity=descriptor.get_plugin_id(),
|
||||
resources=resources,
|
||||
permissions=descriptor.permissions or {},
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class AgentRunSession(typing.TypedDict):
|
||||
plugin_identity: Plugin identifier (author/name) of the runner
|
||||
conversation_id: Conversation ID for history/event access
|
||||
resources: Authorized resources for this run (from AgentResources)
|
||||
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
|
||||
status: Session status tracking
|
||||
_authorized_ids: Pre-computed authorized resource IDs for O(1) lookup
|
||||
"""
|
||||
@@ -36,6 +37,7 @@ class AgentRunSession(typing.TypedDict):
|
||||
plugin_identity: str # author/name
|
||||
conversation_id: str | None
|
||||
resources: AgentResources
|
||||
permissions: dict[str, list[str]]
|
||||
status: AgentRunSessionStatus
|
||||
_authorized_ids: dict[str, set[str]] # Pre-computed sets for O(1) lookup
|
||||
|
||||
@@ -67,6 +69,7 @@ class AgentRunSessionRegistry:
|
||||
plugin_identity: str,
|
||||
resources: AgentResources,
|
||||
conversation_id: str | None = None,
|
||||
permissions: dict[str, list[str]] | None = None,
|
||||
) -> None:
|
||||
"""Register a new agent run session.
|
||||
|
||||
@@ -77,9 +80,13 @@ class AgentRunSessionRegistry:
|
||||
plugin_identity: Plugin identifier (author/name)
|
||||
resources: Authorized resources for this run
|
||||
conversation_id: Conversation ID for history/event access
|
||||
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
|
||||
"""
|
||||
now = int(time.time())
|
||||
|
||||
# Normalize permissions to empty dict if None
|
||||
permissions = permissions or {}
|
||||
|
||||
# Pre-compute authorized resource IDs for O(1) lookup
|
||||
authorized_ids: dict[str, set[str]] = {
|
||||
'model': {m.get('model_id') for m in resources.get('models', [])},
|
||||
@@ -95,6 +102,7 @@ class AgentRunSessionRegistry:
|
||||
'plugin_identity': plugin_identity,
|
||||
'conversation_id': conversation_id,
|
||||
'resources': resources,
|
||||
'permissions': permissions,
|
||||
'status': {
|
||||
'started_at': now,
|
||||
'last_activity_at': now,
|
||||
|
||||
@@ -7,7 +7,8 @@ import typing
|
||||
import uuid
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from ...entity.persistence.transcript import Transcript
|
||||
|
||||
@@ -27,6 +28,9 @@ class TranscriptStore:
|
||||
|
||||
def __init__(self, engine: AsyncEngine):
|
||||
self.engine = engine
|
||||
self._session_factory = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def append_transcript(
|
||||
self,
|
||||
@@ -72,9 +76,8 @@ class TranscriptStore:
|
||||
# Get next sequence number for this conversation
|
||||
seq = await self._get_next_seq(conversation_id)
|
||||
|
||||
async with self.engine.connect() as conn:
|
||||
await conn.execute(
|
||||
sqlalchemy.insert(Transcript).values(
|
||||
async with self._session_factory() as session:
|
||||
item = Transcript(
|
||||
transcript_id=transcript_id,
|
||||
event_id=event_id,
|
||||
conversation_id=conversation_id,
|
||||
@@ -90,8 +93,8 @@ class TranscriptStore:
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
)
|
||||
await conn.commit()
|
||||
session.add(item)
|
||||
await session.commit()
|
||||
|
||||
return transcript_id
|
||||
|
||||
@@ -119,7 +122,7 @@ class TranscriptStore:
|
||||
"""
|
||||
limit = min(limit, self.HARD_LIMIT)
|
||||
|
||||
async with self.engine.connect() as conn:
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(Transcript).where(
|
||||
Transcript.conversation_id == conversation_id
|
||||
)
|
||||
@@ -136,10 +139,10 @@ class TranscriptStore:
|
||||
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await conn.execute(query)
|
||||
rows = result.fetchall()
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
items = [self._row_to_dict(row[0], include_artifacts) for row in rows[:limit]]
|
||||
items = [self._row_to_dict(row, include_artifacts) for row in rows[:limit]]
|
||||
has_more = len(rows) > limit
|
||||
|
||||
# Calculate cursors
|
||||
@@ -179,7 +182,7 @@ class TranscriptStore:
|
||||
Returns:
|
||||
List of matching items
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(Transcript).where(
|
||||
Transcript.conversation_id == conversation_id,
|
||||
Transcript.content.ilike(f"%{query_text}%"),
|
||||
@@ -194,10 +197,10 @@ class TranscriptStore:
|
||||
|
||||
query = query.order_by(Transcript.seq.desc()).limit(top_k)
|
||||
|
||||
result = await conn.execute(query)
|
||||
rows = result.fetchall()
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
return [self._row_to_dict(row[0], include_artifacts=True) for row in rows]
|
||||
return [self._row_to_dict(row, include_artifacts=True) for row in rows]
|
||||
|
||||
async def get_latest_cursor(
|
||||
self,
|
||||
@@ -211,17 +214,17 @@ class TranscriptStore:
|
||||
Returns:
|
||||
Cursor string (seq number), or None if no items
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(Transcript.seq)
|
||||
.where(Transcript.conversation_id == conversation_id)
|
||||
.order_by(Transcript.seq.desc())
|
||||
.limit(1)
|
||||
)
|
||||
row = result.fetchone()
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return str(row[0])
|
||||
return str(row)
|
||||
|
||||
async def has_history_before(
|
||||
self,
|
||||
@@ -237,8 +240,8 @@ class TranscriptStore:
|
||||
Returns:
|
||||
True if there are items before
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(sqlalchemy.func.count())
|
||||
.select_from(Transcript)
|
||||
.where(
|
||||
@@ -251,8 +254,8 @@ class TranscriptStore:
|
||||
|
||||
async def _get_next_seq(self, conversation_id: str) -> int:
|
||||
"""Get the next sequence number for a conversation."""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))
|
||||
.where(Transcript.conversation_id == conversation_id)
|
||||
)
|
||||
|
||||
77
src/langbot/pkg/entity/persistence/artifact.py
Normal file
77
src/langbot/pkg/entity/persistence/artifact.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Artifact persistence entity for Host-owned artifact store."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy
|
||||
import datetime
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class AgentArtifact(Base):
|
||||
"""AgentArtifact stores metadata for large files, images, tool results, etc.
|
||||
|
||||
This table only stores metadata. The actual blob content is stored in
|
||||
BinaryStorage or external storage, referenced by storage_key.
|
||||
|
||||
Artifacts are accessed via artifact_metadata and artifact_read APIs
|
||||
with run_id authorization.
|
||||
"""
|
||||
|
||||
__tablename__ = 'agent_artifact'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||
"""Auto-increment ID for sequencing."""
|
||||
|
||||
artifact_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, unique=True, index=True)
|
||||
"""Unique artifact identifier."""
|
||||
|
||||
artifact_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=False)
|
||||
"""Artifact type: 'image', 'file', 'voice', 'tool_result', 'platform_attachment', etc."""
|
||||
|
||||
mime_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""MIME type of the content."""
|
||||
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Original file name (if applicable)."""
|
||||
|
||||
size_bytes = sqlalchemy.Column(sqlalchemy.BigInteger, nullable=True)
|
||||
"""Size in bytes."""
|
||||
|
||||
sha256 = sqlalchemy.Column(sqlalchemy.String(64), nullable=True)
|
||||
"""SHA256 hash of content (for integrity verification)."""
|
||||
|
||||
source = sqlalchemy.Column(sqlalchemy.String(50), nullable=False)
|
||||
"""Source of artifact: 'platform', 'runner', 'tool', 'system'."""
|
||||
|
||||
# Storage reference (points to BinaryStorage or external storage)
|
||||
storage_key = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Key in BinaryStorage or external storage reference."""
|
||||
|
||||
storage_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, default='binary_storage')
|
||||
"""Storage type: 'binary_storage', 'file', 'url', etc."""
|
||||
|
||||
# Context
|
||||
conversation_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Conversation this artifact belongs to."""
|
||||
|
||||
run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||
"""Run ID that created this artifact."""
|
||||
|
||||
runner_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Runner ID that created this artifact."""
|
||||
|
||||
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Bot UUID that handled this artifact."""
|
||||
|
||||
workspace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
"""Workspace ID for multi-tenant deployments."""
|
||||
|
||||
# Lifecycle
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow)
|
||||
"""When this artifact was created."""
|
||||
|
||||
expires_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||
"""When this artifact expires (optional)."""
|
||||
|
||||
metadata_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||
"""Additional metadata as JSON string."""
|
||||
@@ -17,6 +17,7 @@ from langbot.pkg.entity.persistence.base import Base
|
||||
# This is required for autogenerate to detect model changes
|
||||
from langbot.pkg.entity.persistence import (
|
||||
apikey,
|
||||
artifact,
|
||||
bot,
|
||||
bstorage,
|
||||
event_log,
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
"""add_agent_artifact_table
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 58846a8d7a81
|
||||
Create Date: 2026-05-23 20:00:00.000000
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers
|
||||
revision = 'a1b2c3d4e5f6'
|
||||
down_revision = '58846a8d7a81'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create agent_artifact table
|
||||
op.create_table(
|
||||
'agent_artifact',
|
||||
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column('artifact_id', sa.String(255), nullable=False, unique=True),
|
||||
sa.Column('artifact_type', sa.String(50), nullable=False),
|
||||
sa.Column('mime_type', sa.String(255), nullable=True),
|
||||
sa.Column('name', sa.String(255), nullable=True),
|
||||
sa.Column('size_bytes', sa.BigInteger(), nullable=True),
|
||||
sa.Column('sha256', sa.String(64), nullable=True),
|
||||
sa.Column('source', sa.String(50), nullable=False),
|
||||
sa.Column('storage_key', sa.String(255), nullable=True),
|
||||
sa.Column('storage_type', sa.String(50), nullable=False, server_default='binary_storage'),
|
||||
sa.Column('conversation_id', sa.String(255), nullable=True),
|
||||
sa.Column('run_id', sa.String(255), nullable=True),
|
||||
sa.Column('runner_id', sa.String(255), nullable=True),
|
||||
sa.Column('bot_id', sa.String(255), nullable=True),
|
||||
sa.Column('workspace_id', sa.String(255), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('metadata_json', sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
# Create indexes for agent_artifact
|
||||
with op.batch_alter_table('agent_artifact', schema=None) as batch_op:
|
||||
batch_op.create_index('ix_agent_artifact_artifact_id', ['artifact_id'], unique=True)
|
||||
batch_op.create_index('ix_agent_artifact_conversation_id', ['conversation_id'], unique=False)
|
||||
batch_op.create_index('ix_agent_artifact_run_id', ['run_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop agent_artifact table
|
||||
with op.batch_alter_table('agent_artifact', schema=None) as batch_op:
|
||||
batch_op.drop_index('ix_agent_artifact_run_id')
|
||||
batch_op.drop_index('ix_agent_artifact_conversation_id')
|
||||
batch_op.drop_index('ix_agent_artifact_artifact_id')
|
||||
|
||||
op.drop_table('agent_artifact')
|
||||
@@ -100,6 +100,47 @@ def _build_tool_detail(tool: Any, requested_tool_name: str | None = None) -> dic
|
||||
}
|
||||
|
||||
|
||||
def _validate_artifact_access(
|
||||
session: dict[str, Any],
|
||||
artifact_metadata: dict[str, Any],
|
||||
operation: str,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Validate artifact access for a run session.
|
||||
|
||||
Authorization rules (evaluated in order, first match wins):
|
||||
1. Artifact run_id matches session run_id → ALLOW (created by this run)
|
||||
2. Artifact has conversation_id AND matches session conversation_id → ALLOW (same conversation)
|
||||
3. Otherwise → DENY
|
||||
|
||||
Note: Artifacts without conversation_id are NOT globally accessible by default.
|
||||
Without an explicit scope field, we enforce strict access control.
|
||||
|
||||
Args:
|
||||
session: AgentRunSession dict with run_id, conversation_id, permissions
|
||||
artifact_metadata: Artifact metadata dict with conversation_id, run_id
|
||||
operation: Operation name for error messages ('metadata' or 'read')
|
||||
|
||||
Returns:
|
||||
Tuple of (is_allowed, error_message). If is_allowed is False, error_message contains reason.
|
||||
"""
|
||||
artifact_conversation_id = artifact_metadata.get('conversation_id')
|
||||
artifact_run_id = artifact_metadata.get('run_id')
|
||||
session_conversation_id = session.get('conversation_id')
|
||||
session_run_id = session.get('run_id')
|
||||
|
||||
# Rule 1: Created by this run (allows cross-conversation access for self-created artifacts)
|
||||
if artifact_run_id and artifact_run_id == session_run_id:
|
||||
return True, None
|
||||
|
||||
# Rule 2: Same conversation (requires artifact to have conversation_id)
|
||||
if artifact_conversation_id and session_conversation_id:
|
||||
if artifact_conversation_id == session_conversation_id:
|
||||
return True, None
|
||||
|
||||
# Rule 3: Deny - no matching authorization rule
|
||||
return False, f'Artifact {operation} access denied: artifact not in session conversation and not created by this run'
|
||||
|
||||
|
||||
def _normalize_uuid_list(values: Any) -> list[str]:
|
||||
"""Normalize a user/config supplied UUID list while preserving order."""
|
||||
if not isinstance(values, list):
|
||||
@@ -1542,6 +1583,169 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
self.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'Event page error: {e}')
|
||||
|
||||
# ================= Artifact APIs =================
|
||||
|
||||
@self.action(PluginToRuntimeAction.ARTIFACT_METADATA)
|
||||
async def artifact_metadata(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Get artifact metadata.
|
||||
|
||||
Requires run_id authorization. Only allows access to artifacts
|
||||
in current run's conversation or created by current run.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
artifact_id = data.get('artifact_id')
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
if not artifact_id:
|
||||
return handler.ActionResponse.error(message='artifact_id is required')
|
||||
|
||||
# Validate run session
|
||||
session_registry = get_session_registry()
|
||||
session = await session_registry.get(run_id)
|
||||
if not session:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Run session {run_id} not found or expired'
|
||||
)
|
||||
|
||||
# Validate caller plugin identity
|
||||
if caller_plugin_identity:
|
||||
session_plugin_identity = session.get('plugin_identity')
|
||||
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
||||
)
|
||||
|
||||
# Check artifact permission from session.permissions (from descriptor.permissions)
|
||||
permissions = session.get('permissions', {})
|
||||
artifact_permissions = permissions.get('artifacts', [])
|
||||
if 'metadata' not in artifact_permissions:
|
||||
return handler.ActionResponse.error(
|
||||
message='Artifact metadata access not authorized'
|
||||
)
|
||||
|
||||
# Get artifact metadata
|
||||
from ..agent.runner.artifact_store import ArtifactStore
|
||||
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
metadata = await store.get_metadata(artifact_id)
|
||||
if not metadata:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Artifact {artifact_id} not found'
|
||||
)
|
||||
|
||||
# Validate artifact access scope
|
||||
is_allowed, error_msg = _validate_artifact_access(session, metadata, 'metadata')
|
||||
if not is_allowed:
|
||||
return handler.ActionResponse.error(message=error_msg)
|
||||
|
||||
return handler.ActionResponse.success(data=metadata)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'ARTIFACT_METADATA error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'Artifact metadata error: {e}')
|
||||
|
||||
@self.action(PluginToRuntimeAction.ARTIFACT_READ)
|
||||
async def artifact_read(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Read artifact content.
|
||||
|
||||
Requires run_id authorization. Only allows access to artifacts
|
||||
in current run's conversation or created by current run.
|
||||
Supports range reads with offset/limit.
|
||||
"""
|
||||
run_id = data.get('run_id')
|
||||
artifact_id = data.get('artifact_id')
|
||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||
|
||||
if not run_id:
|
||||
return handler.ActionResponse.error(message='run_id is required')
|
||||
|
||||
if not artifact_id:
|
||||
return handler.ActionResponse.error(message='artifact_id is required')
|
||||
|
||||
# Validate and parse offset
|
||||
offset = data.get('offset', 0)
|
||||
if not isinstance(offset, int):
|
||||
try:
|
||||
offset = int(offset)
|
||||
except (TypeError, ValueError):
|
||||
return handler.ActionResponse.error(message='offset must be an integer')
|
||||
if offset < 0:
|
||||
return handler.ActionResponse.error(message='offset must be >= 0')
|
||||
|
||||
# Validate and parse limit if provided
|
||||
limit = data.get('limit')
|
||||
if limit is not None:
|
||||
if not isinstance(limit, int):
|
||||
try:
|
||||
limit = int(limit)
|
||||
except (TypeError, ValueError):
|
||||
return handler.ActionResponse.error(message='limit must be an integer')
|
||||
if limit <= 0:
|
||||
return handler.ActionResponse.error(message='limit must be > 0')
|
||||
|
||||
# Validate run session
|
||||
session_registry = get_session_registry()
|
||||
session = await session_registry.get(run_id)
|
||||
if not session:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Run session {run_id} not found or expired'
|
||||
)
|
||||
|
||||
# Validate caller plugin identity
|
||||
if caller_plugin_identity:
|
||||
session_plugin_identity = session.get('plugin_identity')
|
||||
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
||||
)
|
||||
|
||||
# Check artifact permission from session.permissions (from descriptor.permissions)
|
||||
permissions = session.get('permissions', {})
|
||||
artifact_permissions = permissions.get('artifacts', [])
|
||||
if 'read' not in artifact_permissions:
|
||||
return handler.ActionResponse.error(
|
||||
message='Artifact read access not authorized'
|
||||
)
|
||||
|
||||
# Get artifact metadata first to validate access
|
||||
from ..agent.runner.artifact_store import ArtifactStore
|
||||
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
metadata = await store.get_metadata(artifact_id)
|
||||
if not metadata:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Artifact {artifact_id} not found'
|
||||
)
|
||||
|
||||
# Validate artifact access scope
|
||||
is_allowed, error_msg = _validate_artifact_access(session, metadata, 'read')
|
||||
if not is_allowed:
|
||||
return handler.ActionResponse.error(message=error_msg)
|
||||
|
||||
# Read artifact content (validates offset/limit internally)
|
||||
result = await store.read_artifact(
|
||||
artifact_id=artifact_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if not result:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Failed to read artifact {artifact_id}'
|
||||
)
|
||||
|
||||
return handler.ActionResponse.success(data=result)
|
||||
except ValueError as e:
|
||||
# Offset/limit validation error
|
||||
return handler.ActionResponse.error(message=str(e))
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'ARTIFACT_READ error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'Artifact read error: {e}')
|
||||
|
||||
@self.action(CommonAction.PING)
|
||||
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
|
||||
"""Ping"""
|
||||
|
||||
625
tests/unit_tests/agent/test_artifact_store.py
Normal file
625
tests/unit_tests/agent/test_artifact_store.py
Normal file
@@ -0,0 +1,625 @@
|
||||
"""Tests for ArtifactStore and artifact action handlers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import base64
|
||||
import datetime
|
||||
import asyncio
|
||||
|
||||
from langbot.pkg.agent.runner.artifact_store import ArtifactStore
|
||||
from langbot.pkg.agent.runner.session_registry import (
|
||||
AgentRunSessionRegistry,
|
||||
get_session_registry,
|
||||
)
|
||||
|
||||
|
||||
class TestArtifactStore:
|
||||
"""Test ArtifactStore operations."""
|
||||
|
||||
def _make_mock_engine(self):
|
||||
"""Create a mock database engine for AsyncSession-based store.
|
||||
|
||||
Note: The new store uses AsyncSession, so we need to mock
|
||||
the session factory behavior.
|
||||
"""
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
engine = MagicMock(spec=AsyncEngine)
|
||||
return engine
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_artifact_generates_id(self):
|
||||
"""Test register_artifact generates ID if not provided."""
|
||||
engine = self._make_mock_engine()
|
||||
store = ArtifactStore(engine)
|
||||
|
||||
# Mock the session factory
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
artifact_id = await store.register_artifact(
|
||||
artifact_id=None,
|
||||
artifact_type="image",
|
||||
source="platform",
|
||||
)
|
||||
|
||||
assert artifact_id is not None
|
||||
assert len(artifact_id) == 36 # UUID format
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_artifact_with_content(self):
|
||||
"""Test register_artifact stores content in BinaryStorage."""
|
||||
engine = self._make_mock_engine()
|
||||
store = ArtifactStore(engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
content = b"test image content"
|
||||
artifact_id = await store.register_artifact(
|
||||
artifact_id="art_001",
|
||||
artifact_type="image",
|
||||
source="platform",
|
||||
content=content,
|
||||
)
|
||||
|
||||
assert artifact_id == "art_001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_artifact_with_storage_key(self):
|
||||
"""Test register_artifact with pre-existing storage_key."""
|
||||
engine = self._make_mock_engine()
|
||||
store = ArtifactStore(engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
artifact_id = await store.register_artifact(
|
||||
artifact_id="art_002",
|
||||
artifact_type="file",
|
||||
source="runner",
|
||||
storage_key="existing_key",
|
||||
storage_type="binary_storage",
|
||||
size_bytes=1024,
|
||||
)
|
||||
|
||||
assert artifact_id == "art_002"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_not_found(self):
|
||||
"""Test get_metadata returns None if not found."""
|
||||
engine = self._make_mock_engine()
|
||||
store = ArtifactStore(engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
metadata = await store.get_metadata("nonexistent")
|
||||
|
||||
assert metadata is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_artifact_validates_offset(self):
|
||||
"""Test read_artifact rejects negative offset."""
|
||||
engine = self._make_mock_engine()
|
||||
store = ArtifactStore(engine)
|
||||
|
||||
with pytest.raises(ValueError, match="offset must be >= 0"):
|
||||
await store.read_artifact("art_001", offset=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_artifact_validates_limit(self):
|
||||
"""Test read_artifact rejects zero or negative limit."""
|
||||
engine = self._make_mock_engine()
|
||||
store = ArtifactStore(engine)
|
||||
|
||||
with pytest.raises(ValueError, match="limit must be > 0"):
|
||||
await store.read_artifact("art_001", limit=0)
|
||||
|
||||
with pytest.raises(ValueError, match="limit must be > 0"):
|
||||
await store.read_artifact("art_001", limit=-5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_artifact_not_found(self):
|
||||
"""Test read_artifact returns None if not found."""
|
||||
engine = self._make_mock_engine()
|
||||
store = ArtifactStore(engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
result = await store.read_artifact("nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestArtifactAuthorization:
|
||||
"""Test artifact action handler authorization."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_registry(self):
|
||||
"""Create a fresh session registry for testing."""
|
||||
# Reset global registry
|
||||
import langbot.pkg.agent.runner.session_registry as reg
|
||||
reg._global_registry = None
|
||||
return get_session_registry()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_handler(self):
|
||||
"""Create a mock handler for testing actions."""
|
||||
from langbot_plugin.runtime.io.handler import Handler
|
||||
|
||||
class MockHandler(Handler):
|
||||
def __init__(self):
|
||||
self._responses = {}
|
||||
|
||||
async def call_action(self, action, data, timeout=30):
|
||||
# Simulate error response for missing run_id
|
||||
if not data.get("run_id"):
|
||||
return {"ok": False, "message": "run_id is required"}
|
||||
return {"ok": True, "data": {}}
|
||||
|
||||
return MockHandler()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_artifact_metadata_requires_run_id(self, mock_handler):
|
||||
"""Test artifact_metadata requires run_id."""
|
||||
result = await mock_handler.call_action(
|
||||
"artifact_metadata",
|
||||
{"run_id": None, "artifact_id": "art_001"},
|
||||
)
|
||||
|
||||
assert result.get("ok") is False or "error" in str(result).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_artifact_read_requires_run_id(self, mock_handler):
|
||||
"""Test artifact_read requires run_id."""
|
||||
result = await mock_handler.call_action(
|
||||
"artifact_read",
|
||||
{"run_id": None, "artifact_id": "art_001"},
|
||||
)
|
||||
|
||||
assert result.get("ok") is False or "error" in str(result).lower()
|
||||
|
||||
|
||||
class TestArtifactAccessValidation:
|
||||
"""Test _validate_artifact_access authorization rules."""
|
||||
|
||||
def _call_validate(self, session, metadata, operation="metadata"):
|
||||
"""Helper to call the validation function."""
|
||||
from langbot.pkg.plugin.handler import _validate_artifact_access
|
||||
return _validate_artifact_access(session, metadata, operation)
|
||||
|
||||
def test_global_artifact_denied_by_default(self):
|
||||
"""Artifacts without conversation_id are denied by default (no global access)."""
|
||||
session = {
|
||||
"run_id": "run_001",
|
||||
"conversation_id": "conv_001",
|
||||
"permissions": {"artifacts": ["metadata", "read"]},
|
||||
}
|
||||
metadata = {
|
||||
"artifact_id": "art_global",
|
||||
"conversation_id": None, # No conversation scope
|
||||
"run_id": None, # Not created by any run
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is False
|
||||
assert "denied" in error.lower()
|
||||
|
||||
def test_own_run_artifact_allowed(self):
|
||||
"""Artifacts created by same run are allowed (even cross-conversation)."""
|
||||
session = {
|
||||
"run_id": "run_001",
|
||||
"conversation_id": "conv_001",
|
||||
"permissions": {"artifacts": ["metadata", "read"]},
|
||||
}
|
||||
metadata = {
|
||||
"artifact_id": "art_001",
|
||||
"conversation_id": "conv_other", # Different conversation
|
||||
"run_id": "run_001", # Same run
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is True
|
||||
assert error is None
|
||||
|
||||
def test_same_conversation_allowed(self):
|
||||
"""Artifacts in same conversation are allowed."""
|
||||
session = {
|
||||
"run_id": "run_001",
|
||||
"conversation_id": "conv_001",
|
||||
"permissions": {"artifacts": ["metadata", "read"]},
|
||||
}
|
||||
metadata = {
|
||||
"artifact_id": "art_001",
|
||||
"conversation_id": "conv_001", # Same as session
|
||||
"run_id": "run_other", # Different run
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is True
|
||||
assert error is None
|
||||
|
||||
def test_different_conversation_and_run_denied(self):
|
||||
"""Artifacts in different conversation and different run are denied."""
|
||||
session = {
|
||||
"run_id": "run_001",
|
||||
"conversation_id": "conv_001",
|
||||
"permissions": {"artifacts": ["metadata", "read"]},
|
||||
}
|
||||
metadata = {
|
||||
"artifact_id": "art_001",
|
||||
"conversation_id": "conv_other", # Different conversation
|
||||
"run_id": "run_other", # Different run
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is False
|
||||
assert "denied" in error.lower()
|
||||
|
||||
def test_session_without_conversation_denied_for_conversation_artifact(self):
|
||||
"""Session without conversation_id cannot access conversation-scoped artifacts."""
|
||||
session = {
|
||||
"run_id": "run_001",
|
||||
"conversation_id": None, # No conversation
|
||||
"permissions": {"artifacts": ["metadata", "read"]},
|
||||
}
|
||||
metadata = {
|
||||
"artifact_id": "art_001",
|
||||
"conversation_id": "conv_001", # Has conversation
|
||||
"run_id": "run_other", # Different run
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is False
|
||||
|
||||
def test_session_without_conversation_allowed_for_own_artifact(self):
|
||||
"""Session without conversation can access artifacts it created."""
|
||||
session = {
|
||||
"run_id": "run_001",
|
||||
"conversation_id": None, # No conversation
|
||||
"permissions": {"artifacts": ["metadata", "read"]},
|
||||
}
|
||||
metadata = {
|
||||
"artifact_id": "art_001",
|
||||
"conversation_id": "conv_001", # Has conversation
|
||||
"run_id": "run_001", # Same run (created by this run)
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is True
|
||||
|
||||
|
||||
class TestContextAccessArtifactAPIs:
|
||||
"""Test ContextAccess reflects artifact API permissions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_access_has_artifact_apis_when_permitted(self):
|
||||
"""Test ContextAccess shows artifact APIs when permissions allow."""
|
||||
# This tests the context builder logic
|
||||
# When artifact permissions include 'metadata' and 'read',
|
||||
# available_apis should reflect that
|
||||
permissions = {"artifacts": ["metadata", "read"]}
|
||||
|
||||
# Check that permissions are properly interpreted
|
||||
artifact_metadata_enabled = "metadata" in permissions.get("artifacts", [])
|
||||
artifact_read_enabled = "read" in permissions.get("artifacts", [])
|
||||
|
||||
assert artifact_metadata_enabled is True
|
||||
assert artifact_read_enabled is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_access_no_artifact_apis_without_permission(self):
|
||||
"""Test ContextAccess hides artifact APIs when permissions denied."""
|
||||
permissions = {"artifacts": []}
|
||||
|
||||
artifact_metadata_enabled = "metadata" in permissions.get("artifacts", [])
|
||||
artifact_read_enabled = "read" in permissions.get("artifacts", [])
|
||||
|
||||
assert artifact_metadata_enabled is False
|
||||
assert artifact_read_enabled is False
|
||||
|
||||
|
||||
class TestArtifactMetadataFieldAlignment:
|
||||
"""Test that Host returns metadata compatible with SDK ArtifactMetadata."""
|
||||
|
||||
def test_row_to_public_dict_excludes_host_only_fields(self):
|
||||
"""_row_to_public_dict should not return Host-only fields."""
|
||||
from langbot.pkg.agent.runner.artifact_store import ArtifactStore
|
||||
from langbot.pkg.entity.persistence.artifact import AgentArtifact
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create a mock row
|
||||
mock_row = MagicMock(spec=AgentArtifact)
|
||||
mock_row.artifact_id = "art_001"
|
||||
mock_row.artifact_type = "image"
|
||||
mock_row.mime_type = "image/png"
|
||||
mock_row.name = "test.png"
|
||||
mock_row.size_bytes = 1024
|
||||
mock_row.sha256 = "abc123"
|
||||
mock_row.source = "platform"
|
||||
mock_row.conversation_id = "conv_001"
|
||||
mock_row.run_id = "run_001"
|
||||
mock_row.runner_id = "plugin:test/plugin/runner"
|
||||
mock_row.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0)
|
||||
mock_row.expires_at = None
|
||||
mock_row.metadata_json = None
|
||||
|
||||
# These are Host-only fields that should NOT be in output
|
||||
# (they don't exist in SDK ArtifactMetadata)
|
||||
mock_row.bot_id = "bot_001"
|
||||
mock_row.workspace_id = "ws_001"
|
||||
mock_row.storage_key = "artifact:art_001"
|
||||
mock_row.storage_type = "binary_storage"
|
||||
|
||||
store = ArtifactStore(MagicMock())
|
||||
result = store._row_to_public_dict(mock_row)
|
||||
|
||||
# SDK-compatible fields should be present
|
||||
assert result["artifact_id"] == "art_001"
|
||||
assert result["artifact_type"] == "image"
|
||||
assert result["source"] == "platform"
|
||||
assert result["conversation_id"] == "conv_001"
|
||||
assert result["run_id"] == "run_001"
|
||||
|
||||
# Host-only fields should NOT be present
|
||||
assert "bot_id" not in result
|
||||
assert "workspace_id" not in result
|
||||
assert "storage_key" not in result
|
||||
assert "storage_type" not in result
|
||||
|
||||
|
||||
class TestSessionRegistryPermissions:
|
||||
"""Test that session registry stores and retrieves permissions correctly."""
|
||||
|
||||
@pytest.fixture
|
||||
def session_registry(self):
|
||||
"""Create a fresh session registry for testing."""
|
||||
import langbot.pkg.agent.runner.session_registry as reg
|
||||
reg._global_registry = None
|
||||
return get_session_registry()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_stores_permissions(self, session_registry):
|
||||
"""Test that register() stores permissions from descriptor."""
|
||||
await session_registry.register(
|
||||
run_id="run_001",
|
||||
runner_id="plugin:author/plugin/runner",
|
||||
query_id=None,
|
||||
plugin_identity="author/plugin",
|
||||
resources={
|
||||
"models": [],
|
||||
"tools": [],
|
||||
"knowledge_bases": [],
|
||||
"files": [],
|
||||
"storage": {"plugin_storage": True, "workspace_storage": False},
|
||||
"platform_capabilities": {},
|
||||
},
|
||||
permissions={
|
||||
"artifacts": ["metadata", "read"],
|
||||
"history": ["page"],
|
||||
"events": ["get"],
|
||||
},
|
||||
conversation_id="conv_001",
|
||||
)
|
||||
|
||||
session = await session_registry.get("run_001")
|
||||
assert session is not None
|
||||
assert session["permissions"]["artifacts"] == ["metadata", "read"]
|
||||
assert session["permissions"]["history"] == ["page"]
|
||||
assert session["permissions"]["events"] == ["get"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_empty_permissions(self, session_registry):
|
||||
"""Test that register() handles empty permissions."""
|
||||
await session_registry.register(
|
||||
run_id="run_002",
|
||||
runner_id="plugin:author/plugin/runner",
|
||||
query_id=None,
|
||||
plugin_identity="author/plugin",
|
||||
resources={
|
||||
"models": [],
|
||||
"tools": [],
|
||||
"knowledge_bases": [],
|
||||
"files": [],
|
||||
"storage": {"plugin_storage": True, "workspace_storage": False},
|
||||
"platform_capabilities": {},
|
||||
},
|
||||
permissions={},
|
||||
conversation_id="conv_001",
|
||||
)
|
||||
|
||||
session = await session_registry.get("run_002")
|
||||
assert session is not None
|
||||
assert session["permissions"] == {}
|
||||
|
||||
|
||||
class TestArtifactStoreRealSQLite:
|
||||
"""Test ArtifactStore with real SQLite database."""
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(self):
|
||||
"""Create an in-memory SQLite database for testing."""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy import text
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.entity.persistence.artifact import AgentArtifact
|
||||
from langbot.pkg.entity.persistence.bstorage import BinaryStorage
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
|
||||
# Create tables
|
||||
async with engine.begin() as conn:
|
||||
# Create tables manually for in-memory DB
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_get_metadata_round_trip(self, db_engine):
|
||||
"""Test register_artifact -> get_metadata round trip with real DB."""
|
||||
store = ArtifactStore(db_engine)
|
||||
|
||||
# Register artifact with content
|
||||
content = b"test image content for round trip"
|
||||
artifact_id = await store.register_artifact(
|
||||
artifact_id="art_real_001",
|
||||
artifact_type="image",
|
||||
source="platform",
|
||||
mime_type="image/png",
|
||||
name="test.png",
|
||||
content=content,
|
||||
conversation_id="conv_001",
|
||||
run_id="run_001",
|
||||
)
|
||||
|
||||
assert artifact_id == "art_real_001"
|
||||
|
||||
# Get metadata
|
||||
metadata = await store.get_metadata(artifact_id)
|
||||
assert metadata is not None
|
||||
assert metadata["artifact_id"] == "art_real_001"
|
||||
assert metadata["artifact_type"] == "image"
|
||||
assert metadata["mime_type"] == "image/png"
|
||||
assert metadata["source"] == "platform"
|
||||
assert metadata["conversation_id"] == "conv_001"
|
||||
assert metadata["run_id"] == "run_001"
|
||||
|
||||
# Verify Host-only fields are NOT in public metadata
|
||||
assert "storage_key" not in metadata
|
||||
assert "storage_type" not in metadata
|
||||
assert "bot_id" not in metadata
|
||||
assert "workspace_id" not in metadata
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_artifact_round_trip(self, db_engine):
|
||||
"""Test register_artifact -> read_artifact round trip with real DB."""
|
||||
store = ArtifactStore(db_engine)
|
||||
|
||||
# Register artifact with content
|
||||
content = b"test file content for read test"
|
||||
artifact_id = await store.register_artifact(
|
||||
artifact_id="art_real_002",
|
||||
artifact_type="file",
|
||||
source="runner",
|
||||
mime_type="text/plain",
|
||||
name="test.txt",
|
||||
content=content,
|
||||
conversation_id="conv_001",
|
||||
run_id="run_001",
|
||||
)
|
||||
|
||||
# Read artifact
|
||||
result = await store.read_artifact(artifact_id)
|
||||
assert result is not None
|
||||
assert result["artifact_id"] == "art_real_002"
|
||||
assert result["mime_type"] == "text/plain"
|
||||
assert result["offset"] == 0
|
||||
assert result["length"] == len(content)
|
||||
assert result["has_more"] is False
|
||||
|
||||
# Verify content
|
||||
decoded_content = base64.b64decode(result["content_base64"])
|
||||
assert decoded_content == content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_artifact_with_offset_limit(self, db_engine):
|
||||
"""Test read_artifact with offset and limit."""
|
||||
store = ArtifactStore(db_engine)
|
||||
|
||||
# Register artifact with content
|
||||
content = b"0123456789" * 100 # 1000 bytes
|
||||
artifact_id = await store.register_artifact(
|
||||
artifact_id="art_real_003",
|
||||
artifact_type="file",
|
||||
source="runner",
|
||||
mime_type="application/octet-stream",
|
||||
content=content,
|
||||
)
|
||||
|
||||
# Read with offset
|
||||
result = await store.read_artifact(artifact_id, offset=100, limit=100)
|
||||
assert result is not None
|
||||
assert result["offset"] == 100
|
||||
assert result["length"] == 100
|
||||
|
||||
# Verify content
|
||||
decoded_content = base64.b64decode(result["content_base64"])
|
||||
assert decoded_content == content[100:200]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_artifact_has_more(self, db_engine):
|
||||
"""Test read_artifact sets has_more correctly."""
|
||||
store = ArtifactStore(db_engine)
|
||||
|
||||
# Register artifact with content
|
||||
content = b"0123456789" * 100 # 1000 bytes
|
||||
artifact_id = await store.register_artifact(
|
||||
artifact_id="art_real_004",
|
||||
artifact_type="file",
|
||||
source="runner",
|
||||
content=content,
|
||||
)
|
||||
|
||||
# Read with limit smaller than content
|
||||
result = await store.read_artifact(artifact_id, offset=0, limit=100)
|
||||
assert result is not None
|
||||
assert result["has_more"] is True
|
||||
assert result["length"] == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_sdk_validation(self, db_engine):
|
||||
"""Test that metadata can be validated by SDK ArtifactMetadata."""
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.artifact import ArtifactMetadata
|
||||
|
||||
store = ArtifactStore(db_engine)
|
||||
|
||||
# Register artifact
|
||||
artifact_id = await store.register_artifact(
|
||||
artifact_id="art_real_005",
|
||||
artifact_type="file",
|
||||
source="runner",
|
||||
mime_type="application/pdf",
|
||||
name="document.pdf",
|
||||
size_bytes=1024,
|
||||
conversation_id="conv_001",
|
||||
run_id="run_001",
|
||||
runner_id="plugin:test/plugin/runner",
|
||||
)
|
||||
|
||||
# Get metadata
|
||||
metadata = await store.get_metadata(artifact_id)
|
||||
assert metadata is not None
|
||||
|
||||
# Should not raise ValidationError
|
||||
validated = ArtifactMetadata.model_validate(metadata)
|
||||
assert validated.artifact_id == "art_real_005"
|
||||
assert validated.artifact_type == "file"
|
||||
@@ -73,8 +73,17 @@ class TestEventLogStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_event(self, mock_db_engine):
|
||||
"""Test appending an event to EventLog."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_1",
|
||||
event_type="message.received",
|
||||
@@ -93,8 +102,17 @@ class TestEventLogStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_event_truncates_input_summary(self, mock_db_engine):
|
||||
"""Test that long input summaries are truncated."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
long_text = "x" * 2000
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_2",
|
||||
@@ -108,8 +126,19 @@ class TestEventLogStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_events_with_conversation_filter(self, mock_db_engine):
|
||||
"""Test paging events with conversation_id filter."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = EventLogStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
items, next_seq, has_more = await store.page_events(
|
||||
conversation_id="conv_1",
|
||||
limit=10,
|
||||
@@ -124,8 +153,19 @@ class TestTranscriptStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_transcript(self, mock_db_engine):
|
||||
"""Test appending a transcript item."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
# Mock _get_next_seq
|
||||
with patch.object(store, '_get_next_seq', return_value=1):
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
transcript_id = await store.append_transcript(
|
||||
transcript_id=None, # Auto-generate
|
||||
event_id="evt_1",
|
||||
@@ -139,8 +179,18 @@ class TestTranscriptStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_transcript_with_artifacts(self, mock_db_engine):
|
||||
"""Test appending transcript with artifact refs."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
with patch.object(store, '_get_next_seq', return_value=1):
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
transcript_id = await store.append_transcript(
|
||||
transcript_id=None, # Auto-generate
|
||||
event_id="evt_2",
|
||||
@@ -157,8 +207,19 @@ class TestTranscriptStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_transcript_backward(self, mock_db_engine):
|
||||
"""Test paging transcript backward (older items)."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||
conversation_id="conv_1",
|
||||
limit=10,
|
||||
@@ -170,8 +231,19 @@ class TestTranscriptStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_transcript_has_hard_limit(self, mock_db_engine):
|
||||
"""Test that transcript paging has a hard limit."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
# Request more than the hard limit
|
||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||
conversation_id="conv_1",
|
||||
@@ -184,8 +256,19 @@ class TestTranscriptStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_transcript(self, mock_db_engine):
|
||||
"""Test searching transcript."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
items = await store.search_transcript(
|
||||
conversation_id="conv_1",
|
||||
query_text="database",
|
||||
@@ -259,10 +342,19 @@ class TestContextAccessPopulation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_access_has_history_apis_when_permitted(self, mock_db_engine):
|
||||
"""Test ContextAccess shows available APIs based on permissions."""
|
||||
# This would test the context builder logic
|
||||
# For now we verify the store methods work
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
cursor = await store.get_latest_cursor("conv_1")
|
||||
# Should return None or a cursor string
|
||||
assert cursor is None or isinstance(cursor, str)
|
||||
@@ -270,39 +362,224 @@ class TestContextAccessPopulation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_access_shows_has_history_before(self, mock_db_engine):
|
||||
"""Test ContextAccess indicates if history exists."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
store = TranscriptStore(mock_db_engine)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar.return_value = 0
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch.object(store, '_session_factory') as mock_factory:
|
||||
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||
|
||||
has_history = await store.has_history_before("conv_1", 10)
|
||||
assert isinstance(has_history, bool)
|
||||
|
||||
|
||||
class TestEventLogStoreRealSQLite:
|
||||
"""Test EventLogStore with real SQLite database."""
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(self):
|
||||
"""Create an in-memory SQLite database for testing."""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy import text
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.entity.persistence.event_log import EventLog
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
|
||||
# Create tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_get_event_round_trip(self, db_engine):
|
||||
"""Test append_event -> get_event round trip with real DB."""
|
||||
store = EventLogStore(db_engine)
|
||||
|
||||
# Append event
|
||||
event_id = await store.append_event(
|
||||
event_id="evt_real_001",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
bot_id="bot_001",
|
||||
conversation_id="conv_001",
|
||||
actor_type="user",
|
||||
actor_id="user_001",
|
||||
actor_name="Test User",
|
||||
input_summary="Hello world",
|
||||
run_id="run_001",
|
||||
runner_id="plugin:test/plugin/runner",
|
||||
)
|
||||
|
||||
assert event_id == "evt_real_001"
|
||||
|
||||
# Get event
|
||||
event = await store.get_event(event_id)
|
||||
assert event is not None
|
||||
assert event["event_id"] == "evt_real_001"
|
||||
assert event["event_type"] == "message.received"
|
||||
assert event["source"] == "platform"
|
||||
assert event["conversation_id"] == "conv_001"
|
||||
assert event["actor_type"] == "user"
|
||||
assert event["actor_id"] == "user_001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_page_events(self, db_engine):
|
||||
"""Test page_events with real DB."""
|
||||
store = EventLogStore(db_engine)
|
||||
|
||||
# Append multiple events
|
||||
for i in range(5):
|
||||
await store.append_event(
|
||||
event_id=f"evt_real_{i:03d}",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
conversation_id="conv_001",
|
||||
input_summary=f"Message {i}",
|
||||
)
|
||||
|
||||
# Page events
|
||||
items, next_seq, has_more = await store.page_events(
|
||||
conversation_id="conv_001",
|
||||
limit=3,
|
||||
)
|
||||
|
||||
assert len(items) == 3
|
||||
assert has_more is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_cursor(self, db_engine):
|
||||
"""Test get_latest_cursor with real DB."""
|
||||
store = EventLogStore(db_engine)
|
||||
|
||||
# Append events
|
||||
for i in range(3):
|
||||
await store.append_event(
|
||||
event_id=f"evt_cursor_{i:03d}",
|
||||
event_type="message.received",
|
||||
source="platform",
|
||||
conversation_id="conv_cursor",
|
||||
)
|
||||
|
||||
# Get latest cursor
|
||||
cursor = await store.get_latest_cursor("conv_cursor")
|
||||
assert cursor is not None
|
||||
assert int(cursor) > 0
|
||||
|
||||
|
||||
class TestTranscriptStoreRealSQLite:
|
||||
"""Test TranscriptStore with real SQLite database."""
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine(self):
|
||||
"""Create an in-memory SQLite database for testing."""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy import text
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.entity.persistence.transcript import Transcript
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
|
||||
# Create tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_page_transcript_round_trip(self, db_engine):
|
||||
"""Test append_transcript -> page_transcript round trip with real DB."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
# Append transcript items
|
||||
for i in range(3):
|
||||
await store.append_transcript(
|
||||
transcript_id=f"trans_real_{i:03d}",
|
||||
event_id=f"evt_{i:03d}",
|
||||
conversation_id="conv_001",
|
||||
role="user" if i % 2 == 0 else "assistant",
|
||||
content=f"Message {i}",
|
||||
)
|
||||
|
||||
# Page transcript
|
||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||
conversation_id="conv_001",
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert len(items) == 3
|
||||
assert items[0]["conversation_id"] == "conv_001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_transcript_real_db(self, db_engine):
|
||||
"""Test search_transcript with real DB."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
# Append transcript items
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_search_001",
|
||||
event_id="evt_search_001",
|
||||
conversation_id="conv_search",
|
||||
role="user",
|
||||
content="I want to learn about databases",
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_search_002",
|
||||
event_id="evt_search_002",
|
||||
conversation_id="conv_search",
|
||||
role="assistant",
|
||||
content="Here is information about databases",
|
||||
)
|
||||
|
||||
# Search for "database"
|
||||
items = await store.search_transcript(
|
||||
conversation_id="conv_search",
|
||||
query_text="database",
|
||||
)
|
||||
|
||||
# Should find at least one match
|
||||
assert len(items) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_cursor_real_db(self, db_engine):
|
||||
"""Test get_latest_cursor with real DB."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
# Append transcript items
|
||||
for i in range(3):
|
||||
await store.append_transcript(
|
||||
transcript_id=f"trans_cursor_{i:03d}",
|
||||
event_id=f"evt_cursor_{i:03d}",
|
||||
conversation_id="conv_cursor",
|
||||
role="user",
|
||||
content=f"Message {i}",
|
||||
)
|
||||
|
||||
# Get latest cursor
|
||||
cursor = await store.get_latest_cursor("conv_cursor")
|
||||
assert cursor is not None
|
||||
assert int(cursor) > 0
|
||||
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_engine():
|
||||
"""Create a mock database engine."""
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
"""Create a mock database engine for AsyncSession-based stores."""
|
||||
from unittest.mock import MagicMock
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
engine = MagicMock(spec=AsyncEngine)
|
||||
|
||||
# Mock connection
|
||||
mock_conn = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.fetchone.return_value = None
|
||||
mock_result.fetchall.return_value = []
|
||||
mock_result.scalar.return_value = 0
|
||||
mock_conn.execute = AsyncMock(return_value=mock_result)
|
||||
mock_conn.commit = AsyncMock()
|
||||
|
||||
# Create async context manager for connect()
|
||||
class AsyncConnectContextManager:
|
||||
async def __aenter__(self):
|
||||
return mock_conn
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
# connect() should return an async context manager
|
||||
engine.connect = MagicMock(return_value=AsyncConnectContextManager())
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user