mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat(agent-runner): add artifact store pull APIs
This commit is contained in:
300
src/langbot/pkg/agent/runner/artifact_store.py
Normal file
300
src/langbot/pkg/agent/runner/artifact_store.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""Artifact store for managing Host-owned artifacts."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import datetime
|
||||||
|
import typing
|
||||||
|
import uuid
|
||||||
|
import base64
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from ...entity.persistence.artifact import AgentArtifact
|
||||||
|
from ...entity.persistence.bstorage import BinaryStorage
|
||||||
|
|
||||||
|
|
||||||
|
class ArtifactStore:
|
||||||
|
"""Store for AgentArtifact records.
|
||||||
|
|
||||||
|
Handles artifact metadata registration and content retrieval.
|
||||||
|
Actual blob storage is delegated to BinaryStorage or external storage.
|
||||||
|
|
||||||
|
All methods are async and use the provided database engine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
engine: AsyncEngine
|
||||||
|
|
||||||
|
# Hard limits
|
||||||
|
MAX_INLINE_READ_BYTES = 1024 * 1024 # 1MB max for inline base64
|
||||||
|
MAX_RANGE_READ_BYTES = 10 * 1024 * 1024 # 10MB max for range reads
|
||||||
|
|
||||||
|
def __init__(self, engine: AsyncEngine):
|
||||||
|
self.engine = engine
|
||||||
|
self._session_factory = sessionmaker(
|
||||||
|
engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
async def register_artifact(
|
||||||
|
self,
|
||||||
|
artifact_id: str | None,
|
||||||
|
artifact_type: str,
|
||||||
|
source: str,
|
||||||
|
storage_key: str | None = None,
|
||||||
|
storage_type: str = 'binary_storage',
|
||||||
|
mime_type: str | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
size_bytes: int | None = None,
|
||||||
|
sha256: str | None = None,
|
||||||
|
conversation_id: str | None = None,
|
||||||
|
run_id: str | None = None,
|
||||||
|
runner_id: str | None = None,
|
||||||
|
bot_id: str | None = None,
|
||||||
|
workspace_id: str | None = None,
|
||||||
|
expires_at: datetime.datetime | None = None,
|
||||||
|
metadata: dict[str, typing.Any] | None = None,
|
||||||
|
content: bytes | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Register a new artifact.
|
||||||
|
|
||||||
|
If content is provided and storage_key is None, stores content
|
||||||
|
in BinaryStorage automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
artifact_id: Unique artifact ID (generated if None)
|
||||||
|
artifact_type: Type of artifact (image, file, voice, tool_result, etc.)
|
||||||
|
source: Source of artifact (platform, runner, tool, system)
|
||||||
|
storage_key: Key in BinaryStorage or external reference
|
||||||
|
storage_type: Storage type (binary_storage, file, url)
|
||||||
|
mime_type: MIME type
|
||||||
|
name: Original file name
|
||||||
|
size_bytes: Size in bytes
|
||||||
|
sha256: SHA256 hash
|
||||||
|
conversation_id: Conversation ID
|
||||||
|
run_id: Run ID that created this
|
||||||
|
runner_id: Runner ID that created this
|
||||||
|
bot_id: Bot UUID
|
||||||
|
workspace_id: Workspace ID
|
||||||
|
expires_at: Expiration time
|
||||||
|
metadata: Additional metadata
|
||||||
|
content: Optional content to store in BinaryStorage
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The artifact_id
|
||||||
|
"""
|
||||||
|
if artifact_id is None:
|
||||||
|
artifact_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# If content provided, store in BinaryStorage
|
||||||
|
if content is not None and storage_key is None:
|
||||||
|
storage_key = f"artifact:{artifact_id}"
|
||||||
|
storage_type = 'binary_storage'
|
||||||
|
if size_bytes is None:
|
||||||
|
size_bytes = len(content)
|
||||||
|
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
# Store content in BinaryStorage if provided
|
||||||
|
if content is not None:
|
||||||
|
binary_storage = BinaryStorage(
|
||||||
|
unique_key=f'artifact:{artifact_id}',
|
||||||
|
key=storage_key,
|
||||||
|
owner_type='artifact',
|
||||||
|
owner='host',
|
||||||
|
value=content,
|
||||||
|
)
|
||||||
|
session.add(binary_storage)
|
||||||
|
|
||||||
|
# Store artifact metadata
|
||||||
|
artifact = AgentArtifact(
|
||||||
|
artifact_id=artifact_id,
|
||||||
|
artifact_type=artifact_type,
|
||||||
|
mime_type=mime_type,
|
||||||
|
name=name,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
sha256=sha256,
|
||||||
|
source=source,
|
||||||
|
storage_key=storage_key,
|
||||||
|
storage_type=storage_type,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
run_id=run_id,
|
||||||
|
runner_id=runner_id,
|
||||||
|
bot_id=bot_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
created_at=datetime.datetime.utcnow(),
|
||||||
|
expires_at=expires_at,
|
||||||
|
metadata_json=json.dumps(metadata) if metadata else None,
|
||||||
|
)
|
||||||
|
session.add(artifact)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return artifact_id
|
||||||
|
|
||||||
|
async def get_metadata(
|
||||||
|
self,
|
||||||
|
artifact_id: str,
|
||||||
|
) -> dict[str, typing.Any] | None:
|
||||||
|
"""Get artifact metadata (public fields only, no internal storage info).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
artifact_id: Artifact ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Artifact metadata dict compatible with SDK ArtifactMetadata, or None if not found
|
||||||
|
"""
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
sqlalchemy.select(AgentArtifact).where(
|
||||||
|
AgentArtifact.artifact_id == artifact_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalars().first()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return self._row_to_public_dict(row)
|
||||||
|
|
||||||
|
async def _get_internal_record(
|
||||||
|
self,
|
||||||
|
artifact_id: str,
|
||||||
|
) -> AgentArtifact | None:
|
||||||
|
"""Get full artifact record including internal fields.
|
||||||
|
|
||||||
|
Used internally by read_artifact to access storage_key/storage_type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
artifact_id: Artifact ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentArtifact ORM instance, or None if not found
|
||||||
|
"""
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
sqlalchemy.select(AgentArtifact).where(
|
||||||
|
AgentArtifact.artifact_id == artifact_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def read_artifact(
|
||||||
|
self,
|
||||||
|
artifact_id: str,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> dict[str, typing.Any] | None:
|
||||||
|
"""Read artifact content.
|
||||||
|
|
||||||
|
For small artifacts, returns content_base64 directly.
|
||||||
|
For large artifacts, returns file_key for chunked transfer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
artifact_id: Artifact ID
|
||||||
|
offset: Byte offset to start reading from (must be >= 0)
|
||||||
|
limit: Maximum bytes to read (must be > 0 if provided)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ArtifactReadResult dict, or None if not found
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If offset < 0 or limit <= 0
|
||||||
|
"""
|
||||||
|
# Validate offset and limit
|
||||||
|
if offset < 0:
|
||||||
|
raise ValueError("offset must be >= 0")
|
||||||
|
|
||||||
|
if limit is not None and limit <= 0:
|
||||||
|
raise ValueError("limit must be > 0")
|
||||||
|
|
||||||
|
# Get internal record (includes storage_key/storage_type)
|
||||||
|
record = await self._get_internal_record(artifact_id)
|
||||||
|
if record is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
storage_type = record.storage_type or 'binary_storage'
|
||||||
|
storage_key = record.storage_key
|
||||||
|
size_bytes = record.size_bytes or 0
|
||||||
|
|
||||||
|
# Cap limit at hard limit
|
||||||
|
if limit is None:
|
||||||
|
limit = self.MAX_INLINE_READ_BYTES
|
||||||
|
limit = min(limit, self.MAX_RANGE_READ_BYTES)
|
||||||
|
|
||||||
|
# For binary_storage, read content
|
||||||
|
if storage_type == 'binary_storage' and storage_key:
|
||||||
|
content = await self._read_binary_storage(storage_key)
|
||||||
|
if content is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Apply offset and limit
|
||||||
|
if offset > 0:
|
||||||
|
content = content[offset:]
|
||||||
|
if limit and len(content) > limit:
|
||||||
|
content = content[:limit]
|
||||||
|
has_more = True
|
||||||
|
else:
|
||||||
|
has_more = False
|
||||||
|
|
||||||
|
return {
|
||||||
|
'artifact_id': artifact_id,
|
||||||
|
'mime_type': record.mime_type,
|
||||||
|
'size_bytes': size_bytes,
|
||||||
|
'offset': offset,
|
||||||
|
'length': len(content),
|
||||||
|
'content_base64': base64.b64encode(content).decode('utf-8'),
|
||||||
|
'file_key': None,
|
||||||
|
'has_more': has_more,
|
||||||
|
}
|
||||||
|
|
||||||
|
# For other storage types, return storage reference
|
||||||
|
# (caller can use file_key for chunked transfer)
|
||||||
|
return {
|
||||||
|
'artifact_id': artifact_id,
|
||||||
|
'mime_type': record.mime_type,
|
||||||
|
'size_bytes': size_bytes,
|
||||||
|
'offset': offset,
|
||||||
|
'length': None,
|
||||||
|
'content_base64': None,
|
||||||
|
'file_key': storage_key,
|
||||||
|
'has_more': False,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _read_binary_storage(self, key: str) -> bytes | None:
|
||||||
|
"""Read content from BinaryStorage.
|
||||||
|
|
||||||
|
Uses unique_key for isolation to prevent cross-artifact access.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The unique_key used when storing the artifact
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Content bytes, or None if not found
|
||||||
|
"""
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
sqlalchemy.select(BinaryStorage).where(BinaryStorage.unique_key == key)
|
||||||
|
)
|
||||||
|
row = result.scalars().first()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return row.value
|
||||||
|
|
||||||
|
def _row_to_public_dict(self, row: AgentArtifact) -> dict[str, typing.Any]:
|
||||||
|
"""Convert an AgentArtifact row to public dict.
|
||||||
|
|
||||||
|
Returns only fields that match SDK ArtifactMetadata entity.
|
||||||
|
Host-only fields (bot_id, workspace_id, storage_key, storage_type) are excluded.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'artifact_id': row.artifact_id,
|
||||||
|
'artifact_type': row.artifact_type,
|
||||||
|
'mime_type': row.mime_type,
|
||||||
|
'name': row.name,
|
||||||
|
'size_bytes': row.size_bytes,
|
||||||
|
'sha256': row.sha256,
|
||||||
|
'source': row.source,
|
||||||
|
'conversation_id': row.conversation_id,
|
||||||
|
'run_id': row.run_id,
|
||||||
|
'runner_id': row.runner_id,
|
||||||
|
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
|
||||||
|
'expires_at': int(row.expires_at.timestamp()) if row.expires_at else None,
|
||||||
|
'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
|
||||||
|
}
|
||||||
@@ -891,11 +891,14 @@ class AgentRunContextBuilder:
|
|||||||
permissions = descriptor.permissions or {}
|
permissions = descriptor.permissions or {}
|
||||||
history_permissions = permissions.get('history', [])
|
history_permissions = permissions.get('history', [])
|
||||||
event_permissions = permissions.get('events', [])
|
event_permissions = permissions.get('events', [])
|
||||||
|
artifact_permissions = permissions.get('artifacts', [])
|
||||||
|
|
||||||
history_page_enabled = 'page' in history_permissions and conversation_id is not None
|
history_page_enabled = 'page' in history_permissions and conversation_id is not None
|
||||||
history_search_enabled = 'search' in history_permissions and conversation_id is not None
|
history_search_enabled = 'search' in history_permissions and conversation_id is not None
|
||||||
event_get_enabled = 'get' in event_permissions
|
event_get_enabled = 'get' in event_permissions
|
||||||
event_page_enabled = 'page' in event_permissions and conversation_id is not None
|
event_page_enabled = 'page' in event_permissions and conversation_id is not None
|
||||||
|
artifact_metadata_enabled = 'metadata' in artifact_permissions
|
||||||
|
artifact_read_enabled = 'read' in artifact_permissions
|
||||||
|
|
||||||
# Get latest cursor and has_history_before if conversation exists
|
# Get latest cursor and has_history_before if conversation exists
|
||||||
latest_cursor = None
|
latest_cursor = None
|
||||||
@@ -931,8 +934,8 @@ class AgentRunContextBuilder:
|
|||||||
'history_search': history_search_enabled,
|
'history_search': history_search_enabled,
|
||||||
'event_get': event_get_enabled,
|
'event_get': event_get_enabled,
|
||||||
'event_page': event_page_enabled,
|
'event_page': event_page_enabled,
|
||||||
'artifact_metadata': False, # TODO: Implement artifact store
|
'artifact_metadata': artifact_metadata_enabled,
|
||||||
'artifact_read': False,
|
'artifact_read': artifact_read_enabled,
|
||||||
'state': True,
|
'state': True,
|
||||||
'storage': True,
|
'storage': True,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ import typing
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from ...entity.persistence.event_log import EventLog
|
from ...entity.persistence.event_log import EventLog
|
||||||
from ...entity.persistence.transcript import Transcript
|
from ...entity.persistence.transcript import Transcript
|
||||||
@@ -27,6 +28,9 @@ class EventLogStore:
|
|||||||
|
|
||||||
def __init__(self, engine: AsyncEngine):
|
def __init__(self, engine: AsyncEngine):
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
|
self._session_factory = sessionmaker(
|
||||||
|
engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
async def append_event(
|
async def append_event(
|
||||||
self,
|
self,
|
||||||
@@ -83,32 +87,31 @@ class EventLogStore:
|
|||||||
if input_summary and len(input_summary) > self.MAX_INPUT_SUMMARY_LENGTH:
|
if input_summary and len(input_summary) > self.MAX_INPUT_SUMMARY_LENGTH:
|
||||||
input_summary = input_summary[:self.MAX_INPUT_SUMMARY_LENGTH - 3] + "..."
|
input_summary = input_summary[:self.MAX_INPUT_SUMMARY_LENGTH - 3] + "..."
|
||||||
|
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
await conn.execute(
|
event = EventLog(
|
||||||
sqlalchemy.insert(EventLog).values(
|
event_id=event_id,
|
||||||
event_id=event_id,
|
event_type=event_type,
|
||||||
event_type=event_type,
|
event_time=event_time,
|
||||||
event_time=event_time,
|
source=source,
|
||||||
source=source,
|
bot_id=bot_id,
|
||||||
bot_id=bot_id,
|
workspace_id=workspace_id,
|
||||||
workspace_id=workspace_id,
|
conversation_id=conversation_id,
|
||||||
conversation_id=conversation_id,
|
thread_id=thread_id,
|
||||||
thread_id=thread_id,
|
actor_type=actor_type,
|
||||||
actor_type=actor_type,
|
actor_id=actor_id,
|
||||||
actor_id=actor_id,
|
actor_name=actor_name,
|
||||||
actor_name=actor_name,
|
subject_type=subject_type,
|
||||||
subject_type=subject_type,
|
subject_id=subject_id,
|
||||||
subject_id=subject_id,
|
input_summary=input_summary,
|
||||||
input_summary=input_summary,
|
input_json=json.dumps(input_json) if input_json else None,
|
||||||
input_json=json.dumps(input_json) if input_json else None,
|
raw_ref=raw_ref,
|
||||||
raw_ref=raw_ref,
|
run_id=run_id,
|
||||||
run_id=run_id,
|
runner_id=runner_id,
|
||||||
runner_id=runner_id,
|
metadata_json=json.dumps(metadata) if metadata else None,
|
||||||
metadata_json=json.dumps(metadata) if metadata else None,
|
created_at=datetime.datetime.utcnow(),
|
||||||
created_at=datetime.datetime.utcnow(),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
await conn.commit()
|
session.add(event)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
return event_id
|
return event_id
|
||||||
|
|
||||||
@@ -124,14 +127,14 @@ class EventLogStore:
|
|||||||
Returns:
|
Returns:
|
||||||
Event record as dict, or None if not found
|
Event record as dict, or None if not found
|
||||||
"""
|
"""
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
result = await conn.execute(
|
result = await session.execute(
|
||||||
sqlalchemy.select(EventLog).where(EventLog.event_id == event_id)
|
sqlalchemy.select(EventLog).where(EventLog.event_id == event_id)
|
||||||
)
|
)
|
||||||
row = result.fetchone()
|
row = result.scalars().first()
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row[0])
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def page_events(
|
async def page_events(
|
||||||
self,
|
self,
|
||||||
@@ -153,7 +156,7 @@ class EventLogStore:
|
|||||||
"""
|
"""
|
||||||
limit = min(limit, 100) # Hard cap
|
limit = min(limit, 100) # Hard cap
|
||||||
|
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
query = sqlalchemy.select(EventLog)
|
query = sqlalchemy.select(EventLog)
|
||||||
|
|
||||||
if conversation_id is not None:
|
if conversation_id is not None:
|
||||||
@@ -167,10 +170,10 @@ class EventLogStore:
|
|||||||
|
|
||||||
query = query.order_by(EventLog.id.desc()).limit(limit + 1)
|
query = query.order_by(EventLog.id.desc()).limit(limit + 1)
|
||||||
|
|
||||||
result = await conn.execute(query)
|
result = await session.execute(query)
|
||||||
rows = result.fetchall()
|
rows = result.scalars().all()
|
||||||
|
|
||||||
items = [self._row_to_dict(row[0]) for row in rows[:limit]]
|
items = [self._row_to_dict(row) for row in rows[:limit]]
|
||||||
has_more = len(rows) > limit
|
has_more = len(rows) > limit
|
||||||
next_seq = items[-1]['id'] if items and has_more else None
|
next_seq = items[-1]['id'] if items and has_more else None
|
||||||
|
|
||||||
@@ -188,17 +191,17 @@ class EventLogStore:
|
|||||||
Returns:
|
Returns:
|
||||||
Cursor string (seq number), or None if no events
|
Cursor string (seq number), or None if no events
|
||||||
"""
|
"""
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
result = await conn.execute(
|
result = await session.execute(
|
||||||
sqlalchemy.select(EventLog.id)
|
sqlalchemy.select(EventLog.id)
|
||||||
.where(EventLog.conversation_id == conversation_id)
|
.where(EventLog.conversation_id == conversation_id)
|
||||||
.order_by(EventLog.id.desc())
|
.order_by(EventLog.id.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
row = result.fetchone()
|
row = result.scalars().first()
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
return str(row[0])
|
return str(row)
|
||||||
|
|
||||||
async def has_events_before(
|
async def has_events_before(
|
||||||
self,
|
self,
|
||||||
@@ -214,8 +217,8 @@ class EventLogStore:
|
|||||||
Returns:
|
Returns:
|
||||||
True if there are events before
|
True if there are events before
|
||||||
"""
|
"""
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
result = await conn.execute(
|
result = await session.execute(
|
||||||
sqlalchemy.select(sqlalchemy.func.count())
|
sqlalchemy.select(sqlalchemy.func.count())
|
||||||
.select_from(EventLog)
|
.select_from(EventLog)
|
||||||
.where(
|
.where(
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ class AgentRunOrchestrator:
|
|||||||
query_id=None, # No query_id in event-first mode
|
query_id=None, # No query_id in event-first mode
|
||||||
plugin_identity=descriptor.get_plugin_id(),
|
plugin_identity=descriptor.get_plugin_id(),
|
||||||
resources=resources,
|
resources=resources,
|
||||||
|
permissions=descriptor.permissions or {},
|
||||||
conversation_id=event.conversation_id,
|
conversation_id=event.conversation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -222,6 +223,7 @@ class AgentRunOrchestrator:
|
|||||||
query_id=query.query_id,
|
query_id=query.query_id,
|
||||||
plugin_identity=descriptor.get_plugin_id(),
|
plugin_identity=descriptor.get_plugin_id(),
|
||||||
resources=resources,
|
resources=resources,
|
||||||
|
permissions=descriptor.permissions or {},
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class AgentRunSession(typing.TypedDict):
|
|||||||
plugin_identity: Plugin identifier (author/name) of the runner
|
plugin_identity: Plugin identifier (author/name) of the runner
|
||||||
conversation_id: Conversation ID for history/event access
|
conversation_id: Conversation ID for history/event access
|
||||||
resources: Authorized resources for this run (from AgentResources)
|
resources: Authorized resources for this run (from AgentResources)
|
||||||
|
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
|
||||||
status: Session status tracking
|
status: Session status tracking
|
||||||
_authorized_ids: Pre-computed authorized resource IDs for O(1) lookup
|
_authorized_ids: Pre-computed authorized resource IDs for O(1) lookup
|
||||||
"""
|
"""
|
||||||
@@ -36,6 +37,7 @@ class AgentRunSession(typing.TypedDict):
|
|||||||
plugin_identity: str # author/name
|
plugin_identity: str # author/name
|
||||||
conversation_id: str | None
|
conversation_id: str | None
|
||||||
resources: AgentResources
|
resources: AgentResources
|
||||||
|
permissions: dict[str, list[str]]
|
||||||
status: AgentRunSessionStatus
|
status: AgentRunSessionStatus
|
||||||
_authorized_ids: dict[str, set[str]] # Pre-computed sets for O(1) lookup
|
_authorized_ids: dict[str, set[str]] # Pre-computed sets for O(1) lookup
|
||||||
|
|
||||||
@@ -67,6 +69,7 @@ class AgentRunSessionRegistry:
|
|||||||
plugin_identity: str,
|
plugin_identity: str,
|
||||||
resources: AgentResources,
|
resources: AgentResources,
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
|
permissions: dict[str, list[str]] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register a new agent run session.
|
"""Register a new agent run session.
|
||||||
|
|
||||||
@@ -77,9 +80,13 @@ class AgentRunSessionRegistry:
|
|||||||
plugin_identity: Plugin identifier (author/name)
|
plugin_identity: Plugin identifier (author/name)
|
||||||
resources: Authorized resources for this run
|
resources: Authorized resources for this run
|
||||||
conversation_id: Conversation ID for history/event access
|
conversation_id: Conversation ID for history/event access
|
||||||
|
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
|
||||||
"""
|
"""
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
|
|
||||||
|
# Normalize permissions to empty dict if None
|
||||||
|
permissions = permissions or {}
|
||||||
|
|
||||||
# Pre-compute authorized resource IDs for O(1) lookup
|
# Pre-compute authorized resource IDs for O(1) lookup
|
||||||
authorized_ids: dict[str, set[str]] = {
|
authorized_ids: dict[str, set[str]] = {
|
||||||
'model': {m.get('model_id') for m in resources.get('models', [])},
|
'model': {m.get('model_id') for m in resources.get('models', [])},
|
||||||
@@ -95,6 +102,7 @@ class AgentRunSessionRegistry:
|
|||||||
'plugin_identity': plugin_identity,
|
'plugin_identity': plugin_identity,
|
||||||
'conversation_id': conversation_id,
|
'conversation_id': conversation_id,
|
||||||
'resources': resources,
|
'resources': resources,
|
||||||
|
'permissions': permissions,
|
||||||
'status': {
|
'status': {
|
||||||
'started_at': now,
|
'started_at': now,
|
||||||
'last_activity_at': now,
|
'last_activity_at': now,
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ import typing
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from ...entity.persistence.transcript import Transcript
|
from ...entity.persistence.transcript import Transcript
|
||||||
|
|
||||||
@@ -27,6 +28,9 @@ class TranscriptStore:
|
|||||||
|
|
||||||
def __init__(self, engine: AsyncEngine):
|
def __init__(self, engine: AsyncEngine):
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
|
self._session_factory = sessionmaker(
|
||||||
|
engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
async def append_transcript(
|
async def append_transcript(
|
||||||
self,
|
self,
|
||||||
@@ -72,26 +76,25 @@ class TranscriptStore:
|
|||||||
# Get next sequence number for this conversation
|
# Get next sequence number for this conversation
|
||||||
seq = await self._get_next_seq(conversation_id)
|
seq = await self._get_next_seq(conversation_id)
|
||||||
|
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
await conn.execute(
|
item = Transcript(
|
||||||
sqlalchemy.insert(Transcript).values(
|
transcript_id=transcript_id,
|
||||||
transcript_id=transcript_id,
|
event_id=event_id,
|
||||||
event_id=event_id,
|
conversation_id=conversation_id,
|
||||||
conversation_id=conversation_id,
|
thread_id=thread_id,
|
||||||
thread_id=thread_id,
|
role=role,
|
||||||
role=role,
|
item_type=item_type,
|
||||||
item_type=item_type,
|
content=content,
|
||||||
content=content,
|
content_json=json.dumps(content_json) if content_json else None,
|
||||||
content_json=json.dumps(content_json) if content_json else None,
|
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
|
||||||
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
|
seq=seq,
|
||||||
seq=seq,
|
run_id=run_id,
|
||||||
run_id=run_id,
|
runner_id=runner_id,
|
||||||
runner_id=runner_id,
|
created_at=datetime.datetime.utcnow(),
|
||||||
created_at=datetime.datetime.utcnow(),
|
metadata_json=json.dumps(metadata) if metadata else None,
|
||||||
metadata_json=json.dumps(metadata) if metadata else None,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
await conn.commit()
|
session.add(item)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
return transcript_id
|
return transcript_id
|
||||||
|
|
||||||
@@ -119,7 +122,7 @@ class TranscriptStore:
|
|||||||
"""
|
"""
|
||||||
limit = min(limit, self.HARD_LIMIT)
|
limit = min(limit, self.HARD_LIMIT)
|
||||||
|
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
query = sqlalchemy.select(Transcript).where(
|
query = sqlalchemy.select(Transcript).where(
|
||||||
Transcript.conversation_id == conversation_id
|
Transcript.conversation_id == conversation_id
|
||||||
)
|
)
|
||||||
@@ -136,10 +139,10 @@ class TranscriptStore:
|
|||||||
|
|
||||||
query = query.limit(limit + 1)
|
query = query.limit(limit + 1)
|
||||||
|
|
||||||
result = await conn.execute(query)
|
result = await session.execute(query)
|
||||||
rows = result.fetchall()
|
rows = result.scalars().all()
|
||||||
|
|
||||||
items = [self._row_to_dict(row[0], include_artifacts) for row in rows[:limit]]
|
items = [self._row_to_dict(row, include_artifacts) for row in rows[:limit]]
|
||||||
has_more = len(rows) > limit
|
has_more = len(rows) > limit
|
||||||
|
|
||||||
# Calculate cursors
|
# Calculate cursors
|
||||||
@@ -179,7 +182,7 @@ class TranscriptStore:
|
|||||||
Returns:
|
Returns:
|
||||||
List of matching items
|
List of matching items
|
||||||
"""
|
"""
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
query = sqlalchemy.select(Transcript).where(
|
query = sqlalchemy.select(Transcript).where(
|
||||||
Transcript.conversation_id == conversation_id,
|
Transcript.conversation_id == conversation_id,
|
||||||
Transcript.content.ilike(f"%{query_text}%"),
|
Transcript.content.ilike(f"%{query_text}%"),
|
||||||
@@ -194,10 +197,10 @@ class TranscriptStore:
|
|||||||
|
|
||||||
query = query.order_by(Transcript.seq.desc()).limit(top_k)
|
query = query.order_by(Transcript.seq.desc()).limit(top_k)
|
||||||
|
|
||||||
result = await conn.execute(query)
|
result = await session.execute(query)
|
||||||
rows = result.fetchall()
|
rows = result.scalars().all()
|
||||||
|
|
||||||
return [self._row_to_dict(row[0], include_artifacts=True) for row in rows]
|
return [self._row_to_dict(row, include_artifacts=True) for row in rows]
|
||||||
|
|
||||||
async def get_latest_cursor(
|
async def get_latest_cursor(
|
||||||
self,
|
self,
|
||||||
@@ -211,17 +214,17 @@ class TranscriptStore:
|
|||||||
Returns:
|
Returns:
|
||||||
Cursor string (seq number), or None if no items
|
Cursor string (seq number), or None if no items
|
||||||
"""
|
"""
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
result = await conn.execute(
|
result = await session.execute(
|
||||||
sqlalchemy.select(Transcript.seq)
|
sqlalchemy.select(Transcript.seq)
|
||||||
.where(Transcript.conversation_id == conversation_id)
|
.where(Transcript.conversation_id == conversation_id)
|
||||||
.order_by(Transcript.seq.desc())
|
.order_by(Transcript.seq.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
row = result.fetchone()
|
row = result.scalars().first()
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
return str(row[0])
|
return str(row)
|
||||||
|
|
||||||
async def has_history_before(
|
async def has_history_before(
|
||||||
self,
|
self,
|
||||||
@@ -237,8 +240,8 @@ class TranscriptStore:
|
|||||||
Returns:
|
Returns:
|
||||||
True if there are items before
|
True if there are items before
|
||||||
"""
|
"""
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
result = await conn.execute(
|
result = await session.execute(
|
||||||
sqlalchemy.select(sqlalchemy.func.count())
|
sqlalchemy.select(sqlalchemy.func.count())
|
||||||
.select_from(Transcript)
|
.select_from(Transcript)
|
||||||
.where(
|
.where(
|
||||||
@@ -251,8 +254,8 @@ class TranscriptStore:
|
|||||||
|
|
||||||
async def _get_next_seq(self, conversation_id: str) -> int:
|
async def _get_next_seq(self, conversation_id: str) -> int:
|
||||||
"""Get the next sequence number for a conversation."""
|
"""Get the next sequence number for a conversation."""
|
||||||
async with self.engine.connect() as conn:
|
async with self._session_factory() as session:
|
||||||
result = await conn.execute(
|
result = await session.execute(
|
||||||
sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))
|
sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))
|
||||||
.where(Transcript.conversation_id == conversation_id)
|
.where(Transcript.conversation_id == conversation_id)
|
||||||
)
|
)
|
||||||
|
|||||||
77
src/langbot/pkg/entity/persistence/artifact.py
Normal file
77
src/langbot/pkg/entity/persistence/artifact.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Artifact persistence entity for Host-owned artifact store."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from .base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class AgentArtifact(Base):
|
||||||
|
"""AgentArtifact stores metadata for large files, images, tool results, etc.
|
||||||
|
|
||||||
|
This table only stores metadata. The actual blob content is stored in
|
||||||
|
BinaryStorage or external storage, referenced by storage_key.
|
||||||
|
|
||||||
|
Artifacts are accessed via artifact_metadata and artifact_read APIs
|
||||||
|
with run_id authorization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = 'agent_artifact'
|
||||||
|
|
||||||
|
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, autoincrement=True)
|
||||||
|
"""Auto-increment ID for sequencing."""
|
||||||
|
|
||||||
|
artifact_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, unique=True, index=True)
|
||||||
|
"""Unique artifact identifier."""
|
||||||
|
|
||||||
|
artifact_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=False)
|
||||||
|
"""Artifact type: 'image', 'file', 'voice', 'tool_result', 'platform_attachment', etc."""
|
||||||
|
|
||||||
|
mime_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
|
"""MIME type of the content."""
|
||||||
|
|
||||||
|
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
|
"""Original file name (if applicable)."""
|
||||||
|
|
||||||
|
size_bytes = sqlalchemy.Column(sqlalchemy.BigInteger, nullable=True)
|
||||||
|
"""Size in bytes."""
|
||||||
|
|
||||||
|
sha256 = sqlalchemy.Column(sqlalchemy.String(64), nullable=True)
|
||||||
|
"""SHA256 hash of content (for integrity verification)."""
|
||||||
|
|
||||||
|
source = sqlalchemy.Column(sqlalchemy.String(50), nullable=False)
|
||||||
|
"""Source of artifact: 'platform', 'runner', 'tool', 'system'."""
|
||||||
|
|
||||||
|
# Storage reference (points to BinaryStorage or external storage)
|
||||||
|
storage_key = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
|
"""Key in BinaryStorage or external storage reference."""
|
||||||
|
|
||||||
|
storage_type = sqlalchemy.Column(sqlalchemy.String(50), nullable=False, default='binary_storage')
|
||||||
|
"""Storage type: 'binary_storage', 'file', 'url', etc."""
|
||||||
|
|
||||||
|
# Context
|
||||||
|
conversation_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||||
|
"""Conversation this artifact belongs to."""
|
||||||
|
|
||||||
|
run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||||
|
"""Run ID that created this artifact."""
|
||||||
|
|
||||||
|
runner_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
|
"""Runner ID that created this artifact."""
|
||||||
|
|
||||||
|
bot_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
|
"""Bot UUID that handled this artifact."""
|
||||||
|
|
||||||
|
workspace_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||||
|
"""Workspace ID for multi-tenant deployments."""
|
||||||
|
|
||||||
|
# Lifecycle
|
||||||
|
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, default=datetime.datetime.utcnow)
|
||||||
|
"""When this artifact was created."""
|
||||||
|
|
||||||
|
expires_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
|
||||||
|
"""When this artifact expires (optional)."""
|
||||||
|
|
||||||
|
metadata_json = sqlalchemy.Column(sqlalchemy.Text, nullable=True)
|
||||||
|
"""Additional metadata as JSON string."""
|
||||||
@@ -17,6 +17,7 @@ from langbot.pkg.entity.persistence.base import Base
|
|||||||
# This is required for autogenerate to detect model changes
|
# This is required for autogenerate to detect model changes
|
||||||
from langbot.pkg.entity.persistence import (
|
from langbot.pkg.entity.persistence import (
|
||||||
apikey,
|
apikey,
|
||||||
|
artifact,
|
||||||
bot,
|
bot,
|
||||||
bstorage,
|
bstorage,
|
||||||
event_log,
|
event_log,
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
"""add_agent_artifact_table
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises: 58846a8d7a81
|
||||||
|
Create Date: 2026-05-23 20:00:00.000000
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers
|
||||||
|
revision = 'a1b2c3d4e5f6'
|
||||||
|
down_revision = '58846a8d7a81'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Create agent_artifact table
|
||||||
|
op.create_table(
|
||||||
|
'agent_artifact',
|
||||||
|
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
|
||||||
|
sa.Column('artifact_id', sa.String(255), nullable=False, unique=True),
|
||||||
|
sa.Column('artifact_type', sa.String(50), nullable=False),
|
||||||
|
sa.Column('mime_type', sa.String(255), nullable=True),
|
||||||
|
sa.Column('name', sa.String(255), nullable=True),
|
||||||
|
sa.Column('size_bytes', sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column('sha256', sa.String(64), nullable=True),
|
||||||
|
sa.Column('source', sa.String(50), nullable=False),
|
||||||
|
sa.Column('storage_key', sa.String(255), nullable=True),
|
||||||
|
sa.Column('storage_type', sa.String(50), nullable=False, server_default='binary_storage'),
|
||||||
|
sa.Column('conversation_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('run_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('runner_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('bot_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('workspace_id', sa.String(255), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.text('(CURRENT_TIMESTAMP)')),
|
||||||
|
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('metadata_json', sa.Text(), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create indexes for agent_artifact
|
||||||
|
with op.batch_alter_table('agent_artifact', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('ix_agent_artifact_artifact_id', ['artifact_id'], unique=True)
|
||||||
|
batch_op.create_index('ix_agent_artifact_conversation_id', ['conversation_id'], unique=False)
|
||||||
|
batch_op.create_index('ix_agent_artifact_run_id', ['run_id'], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Drop agent_artifact table
|
||||||
|
with op.batch_alter_table('agent_artifact', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('ix_agent_artifact_run_id')
|
||||||
|
batch_op.drop_index('ix_agent_artifact_conversation_id')
|
||||||
|
batch_op.drop_index('ix_agent_artifact_artifact_id')
|
||||||
|
|
||||||
|
op.drop_table('agent_artifact')
|
||||||
@@ -100,6 +100,47 @@ def _build_tool_detail(tool: Any, requested_tool_name: str | None = None) -> dic
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_artifact_access(
|
||||||
|
session: dict[str, Any],
|
||||||
|
artifact_metadata: dict[str, Any],
|
||||||
|
operation: str,
|
||||||
|
) -> tuple[bool, str | None]:
|
||||||
|
"""Validate artifact access for a run session.
|
||||||
|
|
||||||
|
Authorization rules (evaluated in order, first match wins):
|
||||||
|
1. Artifact run_id matches session run_id → ALLOW (created by this run)
|
||||||
|
2. Artifact has conversation_id AND matches session conversation_id → ALLOW (same conversation)
|
||||||
|
3. Otherwise → DENY
|
||||||
|
|
||||||
|
Note: Artifacts without conversation_id are NOT globally accessible by default.
|
||||||
|
Without an explicit scope field, we enforce strict access control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: AgentRunSession dict with run_id, conversation_id, permissions
|
||||||
|
artifact_metadata: Artifact metadata dict with conversation_id, run_id
|
||||||
|
operation: Operation name for error messages ('metadata' or 'read')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_allowed, error_message). If is_allowed is False, error_message contains reason.
|
||||||
|
"""
|
||||||
|
artifact_conversation_id = artifact_metadata.get('conversation_id')
|
||||||
|
artifact_run_id = artifact_metadata.get('run_id')
|
||||||
|
session_conversation_id = session.get('conversation_id')
|
||||||
|
session_run_id = session.get('run_id')
|
||||||
|
|
||||||
|
# Rule 1: Created by this run (allows cross-conversation access for self-created artifacts)
|
||||||
|
if artifact_run_id and artifact_run_id == session_run_id:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
# Rule 2: Same conversation (requires artifact to have conversation_id)
|
||||||
|
if artifact_conversation_id and session_conversation_id:
|
||||||
|
if artifact_conversation_id == session_conversation_id:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
# Rule 3: Deny - no matching authorization rule
|
||||||
|
return False, f'Artifact {operation} access denied: artifact not in session conversation and not created by this run'
|
||||||
|
|
||||||
|
|
||||||
def _normalize_uuid_list(values: Any) -> list[str]:
|
def _normalize_uuid_list(values: Any) -> list[str]:
|
||||||
"""Normalize a user/config supplied UUID list while preserving order."""
|
"""Normalize a user/config supplied UUID list while preserving order."""
|
||||||
if not isinstance(values, list):
|
if not isinstance(values, list):
|
||||||
@@ -1542,6 +1583,169 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
self.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True)
|
self.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True)
|
||||||
return handler.ActionResponse.error(message=f'Event page error: {e}')
|
return handler.ActionResponse.error(message=f'Event page error: {e}')
|
||||||
|
|
||||||
|
# ================= Artifact APIs =================
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.ARTIFACT_METADATA)
|
||||||
|
async def artifact_metadata(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Get artifact metadata.
|
||||||
|
|
||||||
|
Requires run_id authorization. Only allows access to artifacts
|
||||||
|
in current run's conversation or created by current run.
|
||||||
|
"""
|
||||||
|
run_id = data.get('run_id')
|
||||||
|
artifact_id = data.get('artifact_id')
|
||||||
|
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||||
|
|
||||||
|
if not run_id:
|
||||||
|
return handler.ActionResponse.error(message='run_id is required')
|
||||||
|
|
||||||
|
if not artifact_id:
|
||||||
|
return handler.ActionResponse.error(message='artifact_id is required')
|
||||||
|
|
||||||
|
# Validate run session
|
||||||
|
session_registry = get_session_registry()
|
||||||
|
session = await session_registry.get(run_id)
|
||||||
|
if not session:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Run session {run_id} not found or expired'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate caller plugin identity
|
||||||
|
if caller_plugin_identity:
|
||||||
|
session_plugin_identity = session.get('plugin_identity')
|
||||||
|
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Plugin identity mismatch for run_id {run_id}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check artifact permission from session.permissions (from descriptor.permissions)
|
||||||
|
permissions = session.get('permissions', {})
|
||||||
|
artifact_permissions = permissions.get('artifacts', [])
|
||||||
|
if 'metadata' not in artifact_permissions:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message='Artifact metadata access not authorized'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get artifact metadata
|
||||||
|
from ..agent.runner.artifact_store import ArtifactStore
|
||||||
|
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata = await store.get_metadata(artifact_id)
|
||||||
|
if not metadata:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Artifact {artifact_id} not found'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate artifact access scope
|
||||||
|
is_allowed, error_msg = _validate_artifact_access(session, metadata, 'metadata')
|
||||||
|
if not is_allowed:
|
||||||
|
return handler.ActionResponse.error(message=error_msg)
|
||||||
|
|
||||||
|
return handler.ActionResponse.success(data=metadata)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f'ARTIFACT_METADATA error: {e}', exc_info=True)
|
||||||
|
return handler.ActionResponse.error(message=f'Artifact metadata error: {e}')
|
||||||
|
|
||||||
|
@self.action(PluginToRuntimeAction.ARTIFACT_READ)
|
||||||
|
async def artifact_read(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
|
"""Read artifact content.
|
||||||
|
|
||||||
|
Requires run_id authorization. Only allows access to artifacts
|
||||||
|
in current run's conversation or created by current run.
|
||||||
|
Supports range reads with offset/limit.
|
||||||
|
"""
|
||||||
|
run_id = data.get('run_id')
|
||||||
|
artifact_id = data.get('artifact_id')
|
||||||
|
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||||
|
|
||||||
|
if not run_id:
|
||||||
|
return handler.ActionResponse.error(message='run_id is required')
|
||||||
|
|
||||||
|
if not artifact_id:
|
||||||
|
return handler.ActionResponse.error(message='artifact_id is required')
|
||||||
|
|
||||||
|
# Validate and parse offset
|
||||||
|
offset = data.get('offset', 0)
|
||||||
|
if not isinstance(offset, int):
|
||||||
|
try:
|
||||||
|
offset = int(offset)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return handler.ActionResponse.error(message='offset must be an integer')
|
||||||
|
if offset < 0:
|
||||||
|
return handler.ActionResponse.error(message='offset must be >= 0')
|
||||||
|
|
||||||
|
# Validate and parse limit if provided
|
||||||
|
limit = data.get('limit')
|
||||||
|
if limit is not None:
|
||||||
|
if not isinstance(limit, int):
|
||||||
|
try:
|
||||||
|
limit = int(limit)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return handler.ActionResponse.error(message='limit must be an integer')
|
||||||
|
if limit <= 0:
|
||||||
|
return handler.ActionResponse.error(message='limit must be > 0')
|
||||||
|
|
||||||
|
# Validate run session
|
||||||
|
session_registry = get_session_registry()
|
||||||
|
session = await session_registry.get(run_id)
|
||||||
|
if not session:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Run session {run_id} not found or expired'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate caller plugin identity
|
||||||
|
if caller_plugin_identity:
|
||||||
|
session_plugin_identity = session.get('plugin_identity')
|
||||||
|
if session_plugin_identity and caller_plugin_identity != session_plugin_identity:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Plugin identity mismatch for run_id {run_id}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check artifact permission from session.permissions (from descriptor.permissions)
|
||||||
|
permissions = session.get('permissions', {})
|
||||||
|
artifact_permissions = permissions.get('artifacts', [])
|
||||||
|
if 'read' not in artifact_permissions:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message='Artifact read access not authorized'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get artifact metadata first to validate access
|
||||||
|
from ..agent.runner.artifact_store import ArtifactStore
|
||||||
|
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata = await store.get_metadata(artifact_id)
|
||||||
|
if not metadata:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Artifact {artifact_id} not found'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate artifact access scope
|
||||||
|
is_allowed, error_msg = _validate_artifact_access(session, metadata, 'read')
|
||||||
|
if not is_allowed:
|
||||||
|
return handler.ActionResponse.error(message=error_msg)
|
||||||
|
|
||||||
|
# Read artifact content (validates offset/limit internally)
|
||||||
|
result = await store.read_artifact(
|
||||||
|
artifact_id=artifact_id,
|
||||||
|
offset=offset,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return handler.ActionResponse.error(
|
||||||
|
message=f'Failed to read artifact {artifact_id}'
|
||||||
|
)
|
||||||
|
|
||||||
|
return handler.ActionResponse.success(data=result)
|
||||||
|
except ValueError as e:
|
||||||
|
# Offset/limit validation error
|
||||||
|
return handler.ActionResponse.error(message=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f'ARTIFACT_READ error: {e}', exc_info=True)
|
||||||
|
return handler.ActionResponse.error(message=f'Artifact read error: {e}')
|
||||||
|
|
||||||
@self.action(CommonAction.PING)
|
@self.action(CommonAction.PING)
|
||||||
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
|
async def ping(data: dict[str, Any]) -> handler.ActionResponse:
|
||||||
"""Ping"""
|
"""Ping"""
|
||||||
|
|||||||
625
tests/unit_tests/agent/test_artifact_store.py
Normal file
625
tests/unit_tests/agent/test_artifact_store.py
Normal file
@@ -0,0 +1,625 @@
|
|||||||
|
"""Tests for ArtifactStore and artifact action handlers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from langbot.pkg.agent.runner.artifact_store import ArtifactStore
|
||||||
|
from langbot.pkg.agent.runner.session_registry import (
|
||||||
|
AgentRunSessionRegistry,
|
||||||
|
get_session_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestArtifactStore:
|
||||||
|
"""Test ArtifactStore operations."""
|
||||||
|
|
||||||
|
def _make_mock_engine(self):
|
||||||
|
"""Create a mock database engine for AsyncSession-based store.
|
||||||
|
|
||||||
|
Note: The new store uses AsyncSession, so we need to mock
|
||||||
|
the session factory behavior.
|
||||||
|
"""
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
|
engine = MagicMock(spec=AsyncEngine)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_artifact_generates_id(self):
|
||||||
|
"""Test register_artifact generates ID if not provided."""
|
||||||
|
engine = self._make_mock_engine()
|
||||||
|
store = ArtifactStore(engine)
|
||||||
|
|
||||||
|
# Mock the session factory
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
mock_session.commit = AsyncMock()
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
artifact_id = await store.register_artifact(
|
||||||
|
artifact_id=None,
|
||||||
|
artifact_type="image",
|
||||||
|
source="platform",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert artifact_id is not None
|
||||||
|
assert len(artifact_id) == 36 # UUID format
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_artifact_with_content(self):
|
||||||
|
"""Test register_artifact stores content in BinaryStorage."""
|
||||||
|
engine = self._make_mock_engine()
|
||||||
|
store = ArtifactStore(engine)
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
mock_session.commit = AsyncMock()
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
content = b"test image content"
|
||||||
|
artifact_id = await store.register_artifact(
|
||||||
|
artifact_id="art_001",
|
||||||
|
artifact_type="image",
|
||||||
|
source="platform",
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert artifact_id == "art_001"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_artifact_with_storage_key(self):
|
||||||
|
"""Test register_artifact with pre-existing storage_key."""
|
||||||
|
engine = self._make_mock_engine()
|
||||||
|
store = ArtifactStore(engine)
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.add = MagicMock()
|
||||||
|
mock_session.commit = AsyncMock()
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
artifact_id = await store.register_artifact(
|
||||||
|
artifact_id="art_002",
|
||||||
|
artifact_type="file",
|
||||||
|
source="runner",
|
||||||
|
storage_key="existing_key",
|
||||||
|
storage_type="binary_storage",
|
||||||
|
size_bytes=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert artifact_id == "art_002"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_metadata_not_found(self):
|
||||||
|
"""Test get_metadata returns None if not found."""
|
||||||
|
engine = self._make_mock_engine()
|
||||||
|
store = ArtifactStore(engine)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalars.return_value.first.return_value = None
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
metadata = await store.get_metadata("nonexistent")
|
||||||
|
|
||||||
|
assert metadata is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_artifact_validates_offset(self):
|
||||||
|
"""Test read_artifact rejects negative offset."""
|
||||||
|
engine = self._make_mock_engine()
|
||||||
|
store = ArtifactStore(engine)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="offset must be >= 0"):
|
||||||
|
await store.read_artifact("art_001", offset=-1)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_artifact_validates_limit(self):
|
||||||
|
"""Test read_artifact rejects zero or negative limit."""
|
||||||
|
engine = self._make_mock_engine()
|
||||||
|
store = ArtifactStore(engine)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="limit must be > 0"):
|
||||||
|
await store.read_artifact("art_001", limit=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="limit must be > 0"):
|
||||||
|
await store.read_artifact("art_001", limit=-5)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_artifact_not_found(self):
|
||||||
|
"""Test read_artifact returns None if not found."""
|
||||||
|
engine = self._make_mock_engine()
|
||||||
|
store = ArtifactStore(engine)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.scalars.return_value.first.return_value = None
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
result = await store.read_artifact("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestArtifactAuthorization:
|
||||||
|
"""Test artifact action handler authorization."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session_registry(self):
|
||||||
|
"""Create a fresh session registry for testing."""
|
||||||
|
# Reset global registry
|
||||||
|
import langbot.pkg.agent.runner.session_registry as reg
|
||||||
|
reg._global_registry = None
|
||||||
|
return get_session_registry()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_handler(self):
|
||||||
|
"""Create a mock handler for testing actions."""
|
||||||
|
from langbot_plugin.runtime.io.handler import Handler
|
||||||
|
|
||||||
|
class MockHandler(Handler):
|
||||||
|
def __init__(self):
|
||||||
|
self._responses = {}
|
||||||
|
|
||||||
|
async def call_action(self, action, data, timeout=30):
|
||||||
|
# Simulate error response for missing run_id
|
||||||
|
if not data.get("run_id"):
|
||||||
|
return {"ok": False, "message": "run_id is required"}
|
||||||
|
return {"ok": True, "data": {}}
|
||||||
|
|
||||||
|
return MockHandler()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_artifact_metadata_requires_run_id(self, mock_handler):
|
||||||
|
"""Test artifact_metadata requires run_id."""
|
||||||
|
result = await mock_handler.call_action(
|
||||||
|
"artifact_metadata",
|
||||||
|
{"run_id": None, "artifact_id": "art_001"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.get("ok") is False or "error" in str(result).lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_artifact_read_requires_run_id(self, mock_handler):
|
||||||
|
"""Test artifact_read requires run_id."""
|
||||||
|
result = await mock_handler.call_action(
|
||||||
|
"artifact_read",
|
||||||
|
{"run_id": None, "artifact_id": "art_001"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.get("ok") is False or "error" in str(result).lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestArtifactAccessValidation:
|
||||||
|
"""Test _validate_artifact_access authorization rules."""
|
||||||
|
|
||||||
|
def _call_validate(self, session, metadata, operation="metadata"):
|
||||||
|
"""Helper to call the validation function."""
|
||||||
|
from langbot.pkg.plugin.handler import _validate_artifact_access
|
||||||
|
return _validate_artifact_access(session, metadata, operation)
|
||||||
|
|
||||||
|
def test_global_artifact_denied_by_default(self):
|
||||||
|
"""Artifacts without conversation_id are denied by default (no global access)."""
|
||||||
|
session = {
|
||||||
|
"run_id": "run_001",
|
||||||
|
"conversation_id": "conv_001",
|
||||||
|
"permissions": {"artifacts": ["metadata", "read"]},
|
||||||
|
}
|
||||||
|
metadata = {
|
||||||
|
"artifact_id": "art_global",
|
||||||
|
"conversation_id": None, # No conversation scope
|
||||||
|
"run_id": None, # Not created by any run
|
||||||
|
}
|
||||||
|
|
||||||
|
is_allowed, error = self._call_validate(session, metadata)
|
||||||
|
assert is_allowed is False
|
||||||
|
assert "denied" in error.lower()
|
||||||
|
|
||||||
|
def test_own_run_artifact_allowed(self):
|
||||||
|
"""Artifacts created by same run are allowed (even cross-conversation)."""
|
||||||
|
session = {
|
||||||
|
"run_id": "run_001",
|
||||||
|
"conversation_id": "conv_001",
|
||||||
|
"permissions": {"artifacts": ["metadata", "read"]},
|
||||||
|
}
|
||||||
|
metadata = {
|
||||||
|
"artifact_id": "art_001",
|
||||||
|
"conversation_id": "conv_other", # Different conversation
|
||||||
|
"run_id": "run_001", # Same run
|
||||||
|
}
|
||||||
|
|
||||||
|
is_allowed, error = self._call_validate(session, metadata)
|
||||||
|
assert is_allowed is True
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_same_conversation_allowed(self):
|
||||||
|
"""Artifacts in same conversation are allowed."""
|
||||||
|
session = {
|
||||||
|
"run_id": "run_001",
|
||||||
|
"conversation_id": "conv_001",
|
||||||
|
"permissions": {"artifacts": ["metadata", "read"]},
|
||||||
|
}
|
||||||
|
metadata = {
|
||||||
|
"artifact_id": "art_001",
|
||||||
|
"conversation_id": "conv_001", # Same as session
|
||||||
|
"run_id": "run_other", # Different run
|
||||||
|
}
|
||||||
|
|
||||||
|
is_allowed, error = self._call_validate(session, metadata)
|
||||||
|
assert is_allowed is True
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
def test_different_conversation_and_run_denied(self):
|
||||||
|
"""Artifacts in different conversation and different run are denied."""
|
||||||
|
session = {
|
||||||
|
"run_id": "run_001",
|
||||||
|
"conversation_id": "conv_001",
|
||||||
|
"permissions": {"artifacts": ["metadata", "read"]},
|
||||||
|
}
|
||||||
|
metadata = {
|
||||||
|
"artifact_id": "art_001",
|
||||||
|
"conversation_id": "conv_other", # Different conversation
|
||||||
|
"run_id": "run_other", # Different run
|
||||||
|
}
|
||||||
|
|
||||||
|
is_allowed, error = self._call_validate(session, metadata)
|
||||||
|
assert is_allowed is False
|
||||||
|
assert "denied" in error.lower()
|
||||||
|
|
||||||
|
def test_session_without_conversation_denied_for_conversation_artifact(self):
|
||||||
|
"""Session without conversation_id cannot access conversation-scoped artifacts."""
|
||||||
|
session = {
|
||||||
|
"run_id": "run_001",
|
||||||
|
"conversation_id": None, # No conversation
|
||||||
|
"permissions": {"artifacts": ["metadata", "read"]},
|
||||||
|
}
|
||||||
|
metadata = {
|
||||||
|
"artifact_id": "art_001",
|
||||||
|
"conversation_id": "conv_001", # Has conversation
|
||||||
|
"run_id": "run_other", # Different run
|
||||||
|
}
|
||||||
|
|
||||||
|
is_allowed, error = self._call_validate(session, metadata)
|
||||||
|
assert is_allowed is False
|
||||||
|
|
||||||
|
def test_session_without_conversation_allowed_for_own_artifact(self):
|
||||||
|
"""Session without conversation can access artifacts it created."""
|
||||||
|
session = {
|
||||||
|
"run_id": "run_001",
|
||||||
|
"conversation_id": None, # No conversation
|
||||||
|
"permissions": {"artifacts": ["metadata", "read"]},
|
||||||
|
}
|
||||||
|
metadata = {
|
||||||
|
"artifact_id": "art_001",
|
||||||
|
"conversation_id": "conv_001", # Has conversation
|
||||||
|
"run_id": "run_001", # Same run (created by this run)
|
||||||
|
}
|
||||||
|
|
||||||
|
is_allowed, error = self._call_validate(session, metadata)
|
||||||
|
assert is_allowed is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextAccessArtifactAPIs:
|
||||||
|
"""Test ContextAccess reflects artifact API permissions."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_access_has_artifact_apis_when_permitted(self):
|
||||||
|
"""Test ContextAccess shows artifact APIs when permissions allow."""
|
||||||
|
# This tests the context builder logic
|
||||||
|
# When artifact permissions include 'metadata' and 'read',
|
||||||
|
# available_apis should reflect that
|
||||||
|
permissions = {"artifacts": ["metadata", "read"]}
|
||||||
|
|
||||||
|
# Check that permissions are properly interpreted
|
||||||
|
artifact_metadata_enabled = "metadata" in permissions.get("artifacts", [])
|
||||||
|
artifact_read_enabled = "read" in permissions.get("artifacts", [])
|
||||||
|
|
||||||
|
assert artifact_metadata_enabled is True
|
||||||
|
assert artifact_read_enabled is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_access_no_artifact_apis_without_permission(self):
|
||||||
|
"""Test ContextAccess hides artifact APIs when permissions denied."""
|
||||||
|
permissions = {"artifacts": []}
|
||||||
|
|
||||||
|
artifact_metadata_enabled = "metadata" in permissions.get("artifacts", [])
|
||||||
|
artifact_read_enabled = "read" in permissions.get("artifacts", [])
|
||||||
|
|
||||||
|
assert artifact_metadata_enabled is False
|
||||||
|
assert artifact_read_enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestArtifactMetadataFieldAlignment:
|
||||||
|
"""Test that Host returns metadata compatible with SDK ArtifactMetadata."""
|
||||||
|
|
||||||
|
def test_row_to_public_dict_excludes_host_only_fields(self):
|
||||||
|
"""_row_to_public_dict should not return Host-only fields."""
|
||||||
|
from langbot.pkg.agent.runner.artifact_store import ArtifactStore
|
||||||
|
from langbot.pkg.entity.persistence.artifact import AgentArtifact
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
# Create a mock row
|
||||||
|
mock_row = MagicMock(spec=AgentArtifact)
|
||||||
|
mock_row.artifact_id = "art_001"
|
||||||
|
mock_row.artifact_type = "image"
|
||||||
|
mock_row.mime_type = "image/png"
|
||||||
|
mock_row.name = "test.png"
|
||||||
|
mock_row.size_bytes = 1024
|
||||||
|
mock_row.sha256 = "abc123"
|
||||||
|
mock_row.source = "platform"
|
||||||
|
mock_row.conversation_id = "conv_001"
|
||||||
|
mock_row.run_id = "run_001"
|
||||||
|
mock_row.runner_id = "plugin:test/plugin/runner"
|
||||||
|
mock_row.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0)
|
||||||
|
mock_row.expires_at = None
|
||||||
|
mock_row.metadata_json = None
|
||||||
|
|
||||||
|
# These are Host-only fields that should NOT be in output
|
||||||
|
# (they don't exist in SDK ArtifactMetadata)
|
||||||
|
mock_row.bot_id = "bot_001"
|
||||||
|
mock_row.workspace_id = "ws_001"
|
||||||
|
mock_row.storage_key = "artifact:art_001"
|
||||||
|
mock_row.storage_type = "binary_storage"
|
||||||
|
|
||||||
|
store = ArtifactStore(MagicMock())
|
||||||
|
result = store._row_to_public_dict(mock_row)
|
||||||
|
|
||||||
|
# SDK-compatible fields should be present
|
||||||
|
assert result["artifact_id"] == "art_001"
|
||||||
|
assert result["artifact_type"] == "image"
|
||||||
|
assert result["source"] == "platform"
|
||||||
|
assert result["conversation_id"] == "conv_001"
|
||||||
|
assert result["run_id"] == "run_001"
|
||||||
|
|
||||||
|
# Host-only fields should NOT be present
|
||||||
|
assert "bot_id" not in result
|
||||||
|
assert "workspace_id" not in result
|
||||||
|
assert "storage_key" not in result
|
||||||
|
assert "storage_type" not in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionRegistryPermissions:
|
||||||
|
"""Test that session registry stores and retrieves permissions correctly."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session_registry(self):
|
||||||
|
"""Create a fresh session registry for testing."""
|
||||||
|
import langbot.pkg.agent.runner.session_registry as reg
|
||||||
|
reg._global_registry = None
|
||||||
|
return get_session_registry()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_stores_permissions(self, session_registry):
|
||||||
|
"""Test that register() stores permissions from descriptor."""
|
||||||
|
await session_registry.register(
|
||||||
|
run_id="run_001",
|
||||||
|
runner_id="plugin:author/plugin/runner",
|
||||||
|
query_id=None,
|
||||||
|
plugin_identity="author/plugin",
|
||||||
|
resources={
|
||||||
|
"models": [],
|
||||||
|
"tools": [],
|
||||||
|
"knowledge_bases": [],
|
||||||
|
"files": [],
|
||||||
|
"storage": {"plugin_storage": True, "workspace_storage": False},
|
||||||
|
"platform_capabilities": {},
|
||||||
|
},
|
||||||
|
permissions={
|
||||||
|
"artifacts": ["metadata", "read"],
|
||||||
|
"history": ["page"],
|
||||||
|
"events": ["get"],
|
||||||
|
},
|
||||||
|
conversation_id="conv_001",
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await session_registry.get("run_001")
|
||||||
|
assert session is not None
|
||||||
|
assert session["permissions"]["artifacts"] == ["metadata", "read"]
|
||||||
|
assert session["permissions"]["history"] == ["page"]
|
||||||
|
assert session["permissions"]["events"] == ["get"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_with_empty_permissions(self, session_registry):
|
||||||
|
"""Test that register() handles empty permissions."""
|
||||||
|
await session_registry.register(
|
||||||
|
run_id="run_002",
|
||||||
|
runner_id="plugin:author/plugin/runner",
|
||||||
|
query_id=None,
|
||||||
|
plugin_identity="author/plugin",
|
||||||
|
resources={
|
||||||
|
"models": [],
|
||||||
|
"tools": [],
|
||||||
|
"knowledge_bases": [],
|
||||||
|
"files": [],
|
||||||
|
"storage": {"plugin_storage": True, "workspace_storage": False},
|
||||||
|
"platform_capabilities": {},
|
||||||
|
},
|
||||||
|
permissions={},
|
||||||
|
conversation_id="conv_001",
|
||||||
|
)
|
||||||
|
|
||||||
|
session = await session_registry.get("run_002")
|
||||||
|
assert session is not None
|
||||||
|
assert session["permissions"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestArtifactStoreRealSQLite:
|
||||||
|
"""Test ArtifactStore with real SQLite database."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_engine(self):
|
||||||
|
"""Create an in-memory SQLite database for testing."""
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
from sqlalchemy import text
|
||||||
|
from langbot.pkg.entity.persistence.base import Base
|
||||||
|
from langbot.pkg.entity.persistence.artifact import AgentArtifact
|
||||||
|
from langbot.pkg.entity.persistence.bstorage import BinaryStorage
|
||||||
|
|
||||||
|
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||||
|
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
# Create tables manually for in-memory DB
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_get_metadata_round_trip(self, db_engine):
|
||||||
|
"""Test register_artifact -> get_metadata round trip with real DB."""
|
||||||
|
store = ArtifactStore(db_engine)
|
||||||
|
|
||||||
|
# Register artifact with content
|
||||||
|
content = b"test image content for round trip"
|
||||||
|
artifact_id = await store.register_artifact(
|
||||||
|
artifact_id="art_real_001",
|
||||||
|
artifact_type="image",
|
||||||
|
source="platform",
|
||||||
|
mime_type="image/png",
|
||||||
|
name="test.png",
|
||||||
|
content=content,
|
||||||
|
conversation_id="conv_001",
|
||||||
|
run_id="run_001",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert artifact_id == "art_real_001"
|
||||||
|
|
||||||
|
# Get metadata
|
||||||
|
metadata = await store.get_metadata(artifact_id)
|
||||||
|
assert metadata is not None
|
||||||
|
assert metadata["artifact_id"] == "art_real_001"
|
||||||
|
assert metadata["artifact_type"] == "image"
|
||||||
|
assert metadata["mime_type"] == "image/png"
|
||||||
|
assert metadata["source"] == "platform"
|
||||||
|
assert metadata["conversation_id"] == "conv_001"
|
||||||
|
assert metadata["run_id"] == "run_001"
|
||||||
|
|
||||||
|
# Verify Host-only fields are NOT in public metadata
|
||||||
|
assert "storage_key" not in metadata
|
||||||
|
assert "storage_type" not in metadata
|
||||||
|
assert "bot_id" not in metadata
|
||||||
|
assert "workspace_id" not in metadata
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_artifact_round_trip(self, db_engine):
|
||||||
|
"""Test register_artifact -> read_artifact round trip with real DB."""
|
||||||
|
store = ArtifactStore(db_engine)
|
||||||
|
|
||||||
|
# Register artifact with content
|
||||||
|
content = b"test file content for read test"
|
||||||
|
artifact_id = await store.register_artifact(
|
||||||
|
artifact_id="art_real_002",
|
||||||
|
artifact_type="file",
|
||||||
|
source="runner",
|
||||||
|
mime_type="text/plain",
|
||||||
|
name="test.txt",
|
||||||
|
content=content,
|
||||||
|
conversation_id="conv_001",
|
||||||
|
run_id="run_001",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read artifact
|
||||||
|
result = await store.read_artifact(artifact_id)
|
||||||
|
assert result is not None
|
||||||
|
assert result["artifact_id"] == "art_real_002"
|
||||||
|
assert result["mime_type"] == "text/plain"
|
||||||
|
assert result["offset"] == 0
|
||||||
|
assert result["length"] == len(content)
|
||||||
|
assert result["has_more"] is False
|
||||||
|
|
||||||
|
# Verify content
|
||||||
|
decoded_content = base64.b64decode(result["content_base64"])
|
||||||
|
assert decoded_content == content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_artifact_with_offset_limit(self, db_engine):
|
||||||
|
"""Test read_artifact with offset and limit."""
|
||||||
|
store = ArtifactStore(db_engine)
|
||||||
|
|
||||||
|
# Register artifact with content
|
||||||
|
content = b"0123456789" * 100 # 1000 bytes
|
||||||
|
artifact_id = await store.register_artifact(
|
||||||
|
artifact_id="art_real_003",
|
||||||
|
artifact_type="file",
|
||||||
|
source="runner",
|
||||||
|
mime_type="application/octet-stream",
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read with offset
|
||||||
|
result = await store.read_artifact(artifact_id, offset=100, limit=100)
|
||||||
|
assert result is not None
|
||||||
|
assert result["offset"] == 100
|
||||||
|
assert result["length"] == 100
|
||||||
|
|
||||||
|
# Verify content
|
||||||
|
decoded_content = base64.b64decode(result["content_base64"])
|
||||||
|
assert decoded_content == content[100:200]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_artifact_has_more(self, db_engine):
|
||||||
|
"""Test read_artifact sets has_more correctly."""
|
||||||
|
store = ArtifactStore(db_engine)
|
||||||
|
|
||||||
|
# Register artifact with content
|
||||||
|
content = b"0123456789" * 100 # 1000 bytes
|
||||||
|
artifact_id = await store.register_artifact(
|
||||||
|
artifact_id="art_real_004",
|
||||||
|
artifact_type="file",
|
||||||
|
source="runner",
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read with limit smaller than content
|
||||||
|
result = await store.read_artifact(artifact_id, offset=0, limit=100)
|
||||||
|
assert result is not None
|
||||||
|
assert result["has_more"] is True
|
||||||
|
assert result["length"] == 100
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_metadata_sdk_validation(self, db_engine):
|
||||||
|
"""Test that metadata can be validated by SDK ArtifactMetadata."""
|
||||||
|
from langbot_plugin.api.entities.builtin.agent_runner.artifact import ArtifactMetadata
|
||||||
|
|
||||||
|
store = ArtifactStore(db_engine)
|
||||||
|
|
||||||
|
# Register artifact
|
||||||
|
artifact_id = await store.register_artifact(
|
||||||
|
artifact_id="art_real_005",
|
||||||
|
artifact_type="file",
|
||||||
|
source="runner",
|
||||||
|
mime_type="application/pdf",
|
||||||
|
name="document.pdf",
|
||||||
|
size_bytes=1024,
|
||||||
|
conversation_id="conv_001",
|
||||||
|
run_id="run_001",
|
||||||
|
runner_id="plugin:test/plugin/runner",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get metadata
|
||||||
|
metadata = await store.get_metadata(artifact_id)
|
||||||
|
assert metadata is not None
|
||||||
|
|
||||||
|
# Should not raise ValidationError
|
||||||
|
validated = ArtifactMetadata.model_validate(metadata)
|
||||||
|
assert validated.artifact_id == "art_real_005"
|
||||||
|
assert validated.artifact_type == "file"
|
||||||
@@ -73,49 +73,78 @@ class TestEventLogStore:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_append_event(self, mock_db_engine):
|
async def test_append_event(self, mock_db_engine):
|
||||||
"""Test appending an event to EventLog."""
|
"""Test appending an event to EventLog."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = EventLogStore(mock_db_engine)
|
store = EventLogStore(mock_db_engine)
|
||||||
|
|
||||||
event_id = await store.append_event(
|
mock_session = AsyncMock()
|
||||||
event_id="evt_1",
|
mock_session.add = MagicMock()
|
||||||
event_type="message.received",
|
mock_session.commit = AsyncMock()
|
||||||
source="platform",
|
|
||||||
bot_id="bot_1",
|
|
||||||
conversation_id="conv_1",
|
|
||||||
actor_type="user",
|
|
||||||
actor_id="user_1",
|
|
||||||
input_summary="Hello world",
|
|
||||||
run_id="run_1",
|
|
||||||
runner_id="plugin:test/plugin/runner",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert event_id == "evt_1"
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
event_id = await store.append_event(
|
||||||
|
event_id="evt_1",
|
||||||
|
event_type="message.received",
|
||||||
|
source="platform",
|
||||||
|
bot_id="bot_1",
|
||||||
|
conversation_id="conv_1",
|
||||||
|
actor_type="user",
|
||||||
|
actor_id="user_1",
|
||||||
|
input_summary="Hello world",
|
||||||
|
run_id="run_1",
|
||||||
|
runner_id="plugin:test/plugin/runner",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event_id == "evt_1"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_append_event_truncates_input_summary(self, mock_db_engine):
|
async def test_append_event_truncates_input_summary(self, mock_db_engine):
|
||||||
"""Test that long input summaries are truncated."""
|
"""Test that long input summaries are truncated."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = EventLogStore(mock_db_engine)
|
store = EventLogStore(mock_db_engine)
|
||||||
|
|
||||||
long_text = "x" * 2000
|
mock_session = AsyncMock()
|
||||||
event_id = await store.append_event(
|
mock_session.add = MagicMock()
|
||||||
event_id="evt_2",
|
mock_session.commit = AsyncMock()
|
||||||
event_type="message.received",
|
|
||||||
source="platform",
|
|
||||||
input_summary=long_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert event_id == "evt_2"
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
long_text = "x" * 2000
|
||||||
|
event_id = await store.append_event(
|
||||||
|
event_id="evt_2",
|
||||||
|
event_type="message.received",
|
||||||
|
source="platform",
|
||||||
|
input_summary=long_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event_id == "evt_2"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_page_events_with_conversation_filter(self, mock_db_engine):
|
async def test_page_events_with_conversation_filter(self, mock_db_engine):
|
||||||
"""Test paging events with conversation_id filter."""
|
"""Test paging events with conversation_id filter."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = EventLogStore(mock_db_engine)
|
store = EventLogStore(mock_db_engine)
|
||||||
|
|
||||||
items, next_seq, has_more = await store.page_events(
|
mock_result = MagicMock()
|
||||||
conversation_id="conv_1",
|
mock_result.scalars.return_value.all.return_value = []
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(items, list)
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
items, next_seq, has_more = await store.page_events(
|
||||||
|
conversation_id="conv_1",
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(items, list)
|
||||||
|
|
||||||
|
|
||||||
class TestTranscriptStore:
|
class TestTranscriptStore:
|
||||||
@@ -124,75 +153,129 @@ class TestTranscriptStore:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_append_transcript(self, mock_db_engine):
|
async def test_append_transcript(self, mock_db_engine):
|
||||||
"""Test appending a transcript item."""
|
"""Test appending a transcript item."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = TranscriptStore(mock_db_engine)
|
store = TranscriptStore(mock_db_engine)
|
||||||
|
|
||||||
transcript_id = await store.append_transcript(
|
mock_session = AsyncMock()
|
||||||
transcript_id=None, # Auto-generate
|
mock_session.add = MagicMock()
|
||||||
event_id="evt_1",
|
mock_session.commit = AsyncMock()
|
||||||
conversation_id="conv_1",
|
|
||||||
role="user",
|
|
||||||
content="Hello",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert transcript_id is not None
|
# Mock _get_next_seq
|
||||||
|
with patch.object(store, '_get_next_seq', return_value=1):
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
transcript_id = await store.append_transcript(
|
||||||
|
transcript_id=None, # Auto-generate
|
||||||
|
event_id="evt_1",
|
||||||
|
conversation_id="conv_1",
|
||||||
|
role="user",
|
||||||
|
content="Hello",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert transcript_id is not None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_append_transcript_with_artifacts(self, mock_db_engine):
|
async def test_append_transcript_with_artifacts(self, mock_db_engine):
|
||||||
"""Test appending transcript with artifact refs."""
|
"""Test appending transcript with artifact refs."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = TranscriptStore(mock_db_engine)
|
store = TranscriptStore(mock_db_engine)
|
||||||
|
|
||||||
transcript_id = await store.append_transcript(
|
mock_session = AsyncMock()
|
||||||
transcript_id=None, # Auto-generate
|
mock_session.add = MagicMock()
|
||||||
event_id="evt_2",
|
mock_session.commit = AsyncMock()
|
||||||
conversation_id="conv_1",
|
|
||||||
role="assistant",
|
|
||||||
content="Here's an image",
|
|
||||||
artifact_refs=[
|
|
||||||
{"artifact_id": "art_1", "artifact_type": "image", "url": "http://example.com/img.png"}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert transcript_id is not None
|
with patch.object(store, '_get_next_seq', return_value=1):
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
transcript_id = await store.append_transcript(
|
||||||
|
transcript_id=None, # Auto-generate
|
||||||
|
event_id="evt_2",
|
||||||
|
conversation_id="conv_1",
|
||||||
|
role="assistant",
|
||||||
|
content="Here's an image",
|
||||||
|
artifact_refs=[
|
||||||
|
{"artifact_id": "art_1", "artifact_type": "image", "url": "http://example.com/img.png"}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert transcript_id is not None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_page_transcript_backward(self, mock_db_engine):
|
async def test_page_transcript_backward(self, mock_db_engine):
|
||||||
"""Test paging transcript backward (older items)."""
|
"""Test paging transcript backward (older items)."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = TranscriptStore(mock_db_engine)
|
store = TranscriptStore(mock_db_engine)
|
||||||
|
|
||||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
mock_result = MagicMock()
|
||||||
conversation_id="conv_1",
|
mock_result.scalars.return_value.all.return_value = []
|
||||||
limit=10,
|
|
||||||
direction="backward",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(items, list)
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||||
|
conversation_id="conv_1",
|
||||||
|
limit=10,
|
||||||
|
direction="backward",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(items, list)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_page_transcript_has_hard_limit(self, mock_db_engine):
|
async def test_page_transcript_has_hard_limit(self, mock_db_engine):
|
||||||
"""Test that transcript paging has a hard limit."""
|
"""Test that transcript paging has a hard limit."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = TranscriptStore(mock_db_engine)
|
store = TranscriptStore(mock_db_engine)
|
||||||
|
|
||||||
# Request more than the hard limit
|
mock_result = MagicMock()
|
||||||
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
mock_result.scalars.return_value.all.return_value = []
|
||||||
conversation_id="conv_1",
|
|
||||||
limit=200, # Request 200, but hard limit is 100
|
|
||||||
)
|
|
||||||
|
|
||||||
# The store should cap at 100
|
mock_session = AsyncMock()
|
||||||
assert len(items) <= store.HARD_LIMIT
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
# Request more than the hard limit
|
||||||
|
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||||
|
conversation_id="conv_1",
|
||||||
|
limit=200, # Request 200, but hard limit is 100
|
||||||
|
)
|
||||||
|
|
||||||
|
# The store should cap at 100
|
||||||
|
assert len(items) <= store.HARD_LIMIT
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_transcript(self, mock_db_engine):
|
async def test_search_transcript(self, mock_db_engine):
|
||||||
"""Test searching transcript."""
|
"""Test searching transcript."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = TranscriptStore(mock_db_engine)
|
store = TranscriptStore(mock_db_engine)
|
||||||
|
|
||||||
items = await store.search_transcript(
|
mock_result = MagicMock()
|
||||||
conversation_id="conv_1",
|
mock_result.scalars.return_value.all.return_value = []
|
||||||
query_text="database",
|
|
||||||
top_k=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(items, list)
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
items = await store.search_transcript(
|
||||||
|
conversation_id="conv_1",
|
||||||
|
query_text="database",
|
||||||
|
top_k=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(items, list)
|
||||||
|
|
||||||
|
|
||||||
class TestHistoryPageAuthorization:
|
class TestHistoryPageAuthorization:
|
||||||
@@ -259,50 +342,244 @@ class TestContextAccessPopulation:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_context_access_has_history_apis_when_permitted(self, mock_db_engine):
|
async def test_context_access_has_history_apis_when_permitted(self, mock_db_engine):
|
||||||
"""Test ContextAccess shows available APIs based on permissions."""
|
"""Test ContextAccess shows available APIs based on permissions."""
|
||||||
# This would test the context builder logic
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
# For now we verify the store methods work
|
|
||||||
store = TranscriptStore(mock_db_engine)
|
store = TranscriptStore(mock_db_engine)
|
||||||
|
|
||||||
cursor = await store.get_latest_cursor("conv_1")
|
mock_result = MagicMock()
|
||||||
# Should return None or a cursor string
|
mock_result.scalars.return_value.first.return_value = None
|
||||||
assert cursor is None or isinstance(cursor, str)
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
cursor = await store.get_latest_cursor("conv_1")
|
||||||
|
# Should return None or a cursor string
|
||||||
|
assert cursor is None or isinstance(cursor, str)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_context_access_shows_has_history_before(self, mock_db_engine):
|
async def test_context_access_shows_has_history_before(self, mock_db_engine):
|
||||||
"""Test ContextAccess indicates if history exists."""
|
"""Test ContextAccess indicates if history exists."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
store = TranscriptStore(mock_db_engine)
|
store = TranscriptStore(mock_db_engine)
|
||||||
|
|
||||||
has_history = await store.has_history_before("conv_1", 10)
|
mock_result = MagicMock()
|
||||||
assert isinstance(has_history, bool)
|
mock_result.scalar.return_value = 0
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
with patch.object(store, '_session_factory') as mock_factory:
|
||||||
|
mock_factory.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
has_history = await store.has_history_before("conv_1", 10)
|
||||||
|
assert isinstance(has_history, bool)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventLogStoreRealSQLite:
|
||||||
|
"""Test EventLogStore with real SQLite database."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_engine(self):
|
||||||
|
"""Create an in-memory SQLite database for testing."""
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
from sqlalchemy import text
|
||||||
|
from langbot.pkg.entity.persistence.base import Base
|
||||||
|
from langbot.pkg.entity.persistence.event_log import EventLog
|
||||||
|
|
||||||
|
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||||
|
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_append_get_event_round_trip(self, db_engine):
|
||||||
|
"""Test append_event -> get_event round trip with real DB."""
|
||||||
|
store = EventLogStore(db_engine)
|
||||||
|
|
||||||
|
# Append event
|
||||||
|
event_id = await store.append_event(
|
||||||
|
event_id="evt_real_001",
|
||||||
|
event_type="message.received",
|
||||||
|
source="platform",
|
||||||
|
bot_id="bot_001",
|
||||||
|
conversation_id="conv_001",
|
||||||
|
actor_type="user",
|
||||||
|
actor_id="user_001",
|
||||||
|
actor_name="Test User",
|
||||||
|
input_summary="Hello world",
|
||||||
|
run_id="run_001",
|
||||||
|
runner_id="plugin:test/plugin/runner",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event_id == "evt_real_001"
|
||||||
|
|
||||||
|
# Get event
|
||||||
|
event = await store.get_event(event_id)
|
||||||
|
assert event is not None
|
||||||
|
assert event["event_id"] == "evt_real_001"
|
||||||
|
assert event["event_type"] == "message.received"
|
||||||
|
assert event["source"] == "platform"
|
||||||
|
assert event["conversation_id"] == "conv_001"
|
||||||
|
assert event["actor_type"] == "user"
|
||||||
|
assert event["actor_id"] == "user_001"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_page_events(self, db_engine):
|
||||||
|
"""Test page_events with real DB."""
|
||||||
|
store = EventLogStore(db_engine)
|
||||||
|
|
||||||
|
# Append multiple events
|
||||||
|
for i in range(5):
|
||||||
|
await store.append_event(
|
||||||
|
event_id=f"evt_real_{i:03d}",
|
||||||
|
event_type="message.received",
|
||||||
|
source="platform",
|
||||||
|
conversation_id="conv_001",
|
||||||
|
input_summary=f"Message {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Page events
|
||||||
|
items, next_seq, has_more = await store.page_events(
|
||||||
|
conversation_id="conv_001",
|
||||||
|
limit=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(items) == 3
|
||||||
|
assert has_more is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_latest_cursor(self, db_engine):
|
||||||
|
"""Test get_latest_cursor with real DB."""
|
||||||
|
store = EventLogStore(db_engine)
|
||||||
|
|
||||||
|
# Append events
|
||||||
|
for i in range(3):
|
||||||
|
await store.append_event(
|
||||||
|
event_id=f"evt_cursor_{i:03d}",
|
||||||
|
event_type="message.received",
|
||||||
|
source="platform",
|
||||||
|
conversation_id="conv_cursor",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get latest cursor
|
||||||
|
cursor = await store.get_latest_cursor("conv_cursor")
|
||||||
|
assert cursor is not None
|
||||||
|
assert int(cursor) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptStoreRealSQLite:
|
||||||
|
"""Test TranscriptStore with real SQLite database."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_engine(self):
|
||||||
|
"""Create an in-memory SQLite database for testing."""
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
from sqlalchemy import text
|
||||||
|
from langbot.pkg.entity.persistence.base import Base
|
||||||
|
from langbot.pkg.entity.persistence.transcript import Transcript
|
||||||
|
|
||||||
|
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||||
|
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_append_page_transcript_round_trip(self, db_engine):
|
||||||
|
"""Test append_transcript -> page_transcript round trip with real DB."""
|
||||||
|
store = TranscriptStore(db_engine)
|
||||||
|
|
||||||
|
# Append transcript items
|
||||||
|
for i in range(3):
|
||||||
|
await store.append_transcript(
|
||||||
|
transcript_id=f"trans_real_{i:03d}",
|
||||||
|
event_id=f"evt_{i:03d}",
|
||||||
|
conversation_id="conv_001",
|
||||||
|
role="user" if i % 2 == 0 else "assistant",
|
||||||
|
content=f"Message {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Page transcript
|
||||||
|
items, next_seq, prev_seq, has_more = await store.page_transcript(
|
||||||
|
conversation_id="conv_001",
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(items) == 3
|
||||||
|
assert items[0]["conversation_id"] == "conv_001"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_transcript_real_db(self, db_engine):
|
||||||
|
"""Test search_transcript with real DB."""
|
||||||
|
store = TranscriptStore(db_engine)
|
||||||
|
|
||||||
|
# Append transcript items
|
||||||
|
await store.append_transcript(
|
||||||
|
transcript_id="trans_search_001",
|
||||||
|
event_id="evt_search_001",
|
||||||
|
conversation_id="conv_search",
|
||||||
|
role="user",
|
||||||
|
content="I want to learn about databases",
|
||||||
|
)
|
||||||
|
await store.append_transcript(
|
||||||
|
transcript_id="trans_search_002",
|
||||||
|
event_id="evt_search_002",
|
||||||
|
conversation_id="conv_search",
|
||||||
|
role="assistant",
|
||||||
|
content="Here is information about databases",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search for "database"
|
||||||
|
items = await store.search_transcript(
|
||||||
|
conversation_id="conv_search",
|
||||||
|
query_text="database",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should find at least one match
|
||||||
|
assert len(items) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_latest_cursor_real_db(self, db_engine):
|
||||||
|
"""Test get_latest_cursor with real DB."""
|
||||||
|
store = TranscriptStore(db_engine)
|
||||||
|
|
||||||
|
# Append transcript items
|
||||||
|
for i in range(3):
|
||||||
|
await store.append_transcript(
|
||||||
|
transcript_id=f"trans_cursor_{i:03d}",
|
||||||
|
event_id=f"evt_cursor_{i:03d}",
|
||||||
|
conversation_id="conv_cursor",
|
||||||
|
role="user",
|
||||||
|
content=f"Message {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get latest cursor
|
||||||
|
cursor = await store.get_latest_cursor("conv_cursor")
|
||||||
|
assert cursor is not None
|
||||||
|
assert int(cursor) > 0
|
||||||
|
|
||||||
|
|
||||||
# Fixtures
|
# Fixtures
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_db_engine():
|
def mock_db_engine():
|
||||||
"""Create a mock database engine."""
|
"""Create a mock database engine for AsyncSession-based stores."""
|
||||||
from unittest.mock import MagicMock, AsyncMock
|
from unittest.mock import MagicMock
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
engine = MagicMock(spec=AsyncEngine)
|
engine = MagicMock(spec=AsyncEngine)
|
||||||
|
|
||||||
# Mock connection
|
|
||||||
mock_conn = MagicMock()
|
|
||||||
mock_result = MagicMock()
|
|
||||||
mock_result.fetchone.return_value = None
|
|
||||||
mock_result.fetchall.return_value = []
|
|
||||||
mock_result.scalar.return_value = 0
|
|
||||||
mock_conn.execute = AsyncMock(return_value=mock_result)
|
|
||||||
mock_conn.commit = AsyncMock()
|
|
||||||
|
|
||||||
# Create async context manager for connect()
|
|
||||||
class AsyncConnectContextManager:
|
|
||||||
async def __aenter__(self):
|
|
||||||
return mock_conn
|
|
||||||
async def __aexit__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# connect() should return an async context manager
|
|
||||||
engine.connect = MagicMock(return_value=AsyncConnectContextManager())
|
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user