mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-04 04:54:36 +00:00
feat(agent-runner): persist created artifacts
This commit is contained in:
@@ -24,9 +24,14 @@ from .pipeline_compat_adapter import PipelineCompatAdapter
|
||||
from .errors import (
|
||||
RunnerNotFoundError,
|
||||
RunnerExecutionError,
|
||||
RunnerProtocolError,
|
||||
)
|
||||
|
||||
|
||||
# Maximum inline artifact content size (1MB)
|
||||
MAX_ARTIFACT_INLINE_BYTES = 1 * 1024 * 1024
|
||||
|
||||
|
||||
class AgentRunOrchestrator:
|
||||
"""Orchestrator for agent runner execution.
|
||||
|
||||
@@ -144,9 +149,25 @@ class AgentRunOrchestrator:
|
||||
event_log_id=event_log_id,
|
||||
)
|
||||
|
||||
# Track artifact refs for assistant transcript (cleared after each message.completed)
|
||||
pending_artifact_refs: list[dict[str, typing.Any]] = []
|
||||
|
||||
try:
|
||||
# Run via plugin connector
|
||||
async for result_dict in self._invoke_runner(descriptor, context):
|
||||
# Handle artifact.created first - consume before normalizer
|
||||
if result_dict.get('type') == 'artifact.created':
|
||||
artifact_ref = await self._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=event,
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
)
|
||||
pending_artifact_refs.append(artifact_ref)
|
||||
# Pass to normalizer for logging, but don't yield to pipeline
|
||||
await self.result_normalizer.normalize(result_dict, descriptor)
|
||||
continue
|
||||
|
||||
# Handle state.updated first - consume before normalizer
|
||||
if result_dict.get('type') == 'state.updated':
|
||||
self._handle_state_updated_event(result_dict, event, descriptor)
|
||||
@@ -156,11 +177,20 @@ class AgentRunOrchestrator:
|
||||
|
||||
# Handle message.completed - write to Transcript
|
||||
if result_dict.get('type') == 'message.completed' and event.conversation_id:
|
||||
# Merge pending artifact refs with message's own refs
|
||||
merged_refs = self._merge_artifact_refs(
|
||||
pending_artifact_refs,
|
||||
result_dict,
|
||||
)
|
||||
# Clear pending refs after attaching to this message
|
||||
pending_artifact_refs.clear()
|
||||
|
||||
await self._write_assistant_transcript(
|
||||
result_dict=result_dict,
|
||||
event=event,
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
artifact_refs=merged_refs if merged_refs else None,
|
||||
)
|
||||
|
||||
# Normalize result for other types
|
||||
@@ -230,6 +260,19 @@ class AgentRunOrchestrator:
|
||||
try:
|
||||
# Run via plugin connector
|
||||
async for result_dict in self._invoke_runner(descriptor, context):
|
||||
# Handle artifact.created - register artifact
|
||||
if result_dict.get('type') == 'artifact.created':
|
||||
await self._handle_artifact_created_query(
|
||||
result_dict=result_dict,
|
||||
query=query,
|
||||
descriptor=descriptor,
|
||||
run_id=run_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
# Pass to normalizer for logging, but don't yield to pipeline
|
||||
await self.result_normalizer.normalize(result_dict, descriptor)
|
||||
continue
|
||||
|
||||
# Handle state.updated first - consume before normalizer
|
||||
if result_dict.get('type') == 'state.updated':
|
||||
self._handle_state_updated(result_dict, query, descriptor)
|
||||
@@ -417,6 +460,101 @@ class AgentRunOrchestrator:
|
||||
)
|
||||
# Invalid scope is already logged by state_store.apply_update
|
||||
|
||||
async def _handle_artifact_created_query(
|
||||
self,
|
||||
result_dict: dict[str, typing.Any],
|
||||
query: pipeline_query.Query,
|
||||
descriptor: AgentRunnerDescriptor,
|
||||
run_id: str,
|
||||
conversation_id: str | None,
|
||||
) -> None:
|
||||
"""Handle artifact.created result in Query-based flow.
|
||||
|
||||
Legacy Query flow only registers artifact metadata/content for compatibility.
|
||||
Event log/transcript linkage is event-first only for now.
|
||||
|
||||
Args:
|
||||
result_dict: Raw result dict with type='artifact.created'
|
||||
query: Pipeline query
|
||||
descriptor: Runner descriptor
|
||||
run_id: Current run ID
|
||||
conversation_id: Conversation ID (may be None)
|
||||
|
||||
Raises:
|
||||
RunnerProtocolError: On validation failures or registration errors
|
||||
"""
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
from .artifact_store import ArtifactStore
|
||||
|
||||
data = result_dict.get('data', {})
|
||||
|
||||
# Validate run_id matches current context
|
||||
result_run_id = result_dict.get('run_id')
|
||||
if result_run_id and result_run_id != run_id:
|
||||
raise RunnerProtocolError(
|
||||
descriptor.id,
|
||||
f'artifact.created run_id mismatch: expected {run_id}, got {result_run_id}',
|
||||
)
|
||||
|
||||
# Extract artifact fields
|
||||
artifact_id = data.get('artifact_id') or str(uuid.uuid4())
|
||||
artifact_type = data.get('artifact_type')
|
||||
if not artifact_type:
|
||||
raise RunnerProtocolError(
|
||||
descriptor.id,
|
||||
'artifact.created missing required field: artifact_type',
|
||||
)
|
||||
|
||||
mime_type = data.get('mime_type')
|
||||
name = data.get('name')
|
||||
size_bytes = data.get('size_bytes')
|
||||
sha256 = data.get('sha256')
|
||||
metadata = data.get('metadata')
|
||||
content_base64 = data.get('content_base64')
|
||||
|
||||
# Decode and validate content if provided
|
||||
content: bytes | None = None
|
||||
if content_base64:
|
||||
try:
|
||||
content = base64.b64decode(content_base64, validate=True)
|
||||
except Exception as e:
|
||||
raise RunnerProtocolError(
|
||||
descriptor.id,
|
||||
f'artifact.created invalid base64 content: {e}',
|
||||
)
|
||||
|
||||
# Validate content size
|
||||
if len(content) > MAX_ARTIFACT_INLINE_BYTES:
|
||||
raise RunnerProtocolError(
|
||||
descriptor.id,
|
||||
f'artifact.created content size {len(content)} bytes exceeds limit {MAX_ARTIFACT_INLINE_BYTES} bytes',
|
||||
)
|
||||
|
||||
# Register artifact via ArtifactStore
|
||||
artifact_store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||
try:
|
||||
await artifact_store.register_artifact(
|
||||
artifact_id=artifact_id,
|
||||
artifact_type=artifact_type,
|
||||
source='runner',
|
||||
mime_type=mime_type,
|
||||
name=name,
|
||||
size_bytes=size_bytes,
|
||||
sha256=sha256,
|
||||
conversation_id=conversation_id,
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
metadata=metadata,
|
||||
content=content,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RunnerProtocolError(
|
||||
descriptor.id,
|
||||
f'artifact.created failed to register artifact: {e}',
|
||||
)
|
||||
|
||||
def _handle_state_updated_event(
|
||||
self,
|
||||
result_dict: dict[str, typing.Any],
|
||||
@@ -552,12 +690,175 @@ class AgentRunOrchestrator:
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_artifact_created(
|
||||
self,
|
||||
result_dict: dict[str, typing.Any],
|
||||
event: AgentEventEnvelope,
|
||||
run_id: str,
|
||||
runner_id: str,
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Handle artifact.created result - register artifact and write EventLog.
|
||||
|
||||
Args:
|
||||
result_dict: Raw result dict with type='artifact.created'
|
||||
event: Event envelope
|
||||
run_id: Current run ID
|
||||
runner_id: Runner ID
|
||||
|
||||
Returns:
|
||||
Artifact reference dict for Transcript
|
||||
|
||||
Raises:
|
||||
RunnerProtocolError: On validation failures or registration errors
|
||||
"""
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
from .artifact_store import ArtifactStore
|
||||
from .event_log_store import EventLogStore
|
||||
|
||||
data = result_dict.get('data', {})
|
||||
|
||||
# Validate run_id matches current context
|
||||
result_run_id = result_dict.get('run_id')
|
||||
if result_run_id and result_run_id != run_id:
|
||||
raise RunnerProtocolError(
|
||||
runner_id,
|
||||
f'artifact.created run_id mismatch: expected {run_id}, got {result_run_id}',
|
||||
)
|
||||
|
||||
# Extract artifact fields
|
||||
artifact_id = data.get('artifact_id') or str(uuid.uuid4())
|
||||
artifact_type = data.get('artifact_type')
|
||||
if not artifact_type:
|
||||
raise RunnerProtocolError(
|
||||
runner_id,
|
||||
'artifact.created missing required field: artifact_type',
|
||||
)
|
||||
|
||||
mime_type = data.get('mime_type')
|
||||
name = data.get('name')
|
||||
size_bytes = data.get('size_bytes')
|
||||
sha256 = data.get('sha256')
|
||||
metadata = data.get('metadata')
|
||||
content_base64 = data.get('content_base64')
|
||||
|
||||
# Decode and validate content if provided
|
||||
content: bytes | None = None
|
||||
if content_base64:
|
||||
try:
|
||||
content = base64.b64decode(content_base64, validate=True)
|
||||
except Exception as e:
|
||||
raise RunnerProtocolError(
|
||||
runner_id,
|
||||
f'artifact.created invalid base64 content: {e}',
|
||||
)
|
||||
|
||||
# Validate content size
|
||||
if len(content) > MAX_ARTIFACT_INLINE_BYTES:
|
||||
raise RunnerProtocolError(
|
||||
runner_id,
|
||||
f'artifact.created content size {len(content)} bytes exceeds limit {MAX_ARTIFACT_INLINE_BYTES} bytes',
|
||||
)
|
||||
|
||||
# Register artifact via ArtifactStore
|
||||
artifact_store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||
try:
|
||||
registered_id = await artifact_store.register_artifact(
|
||||
artifact_id=artifact_id,
|
||||
artifact_type=artifact_type,
|
||||
source='runner',
|
||||
mime_type=mime_type,
|
||||
name=name,
|
||||
size_bytes=size_bytes,
|
||||
sha256=sha256,
|
||||
conversation_id=event.conversation_id,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
bot_id=event.bot_id,
|
||||
workspace_id=event.workspace_id,
|
||||
metadata=metadata,
|
||||
content=content,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RunnerProtocolError(
|
||||
runner_id,
|
||||
f'artifact.created failed to register artifact: {e}',
|
||||
)
|
||||
|
||||
# Write to EventLog
|
||||
event_log_store = EventLogStore(self.ap.persistence_mgr.get_db_engine())
|
||||
await event_log_store.append_event(
|
||||
event_id=str(uuid.uuid4()),
|
||||
event_type='artifact.created',
|
||||
source='runner',
|
||||
bot_id=event.bot_id,
|
||||
workspace_id=event.workspace_id,
|
||||
conversation_id=event.conversation_id,
|
||||
thread_id=event.thread_id,
|
||||
actor_type=event.actor.actor_type if event.actor else None,
|
||||
actor_id=event.actor.actor_id if event.actor else None,
|
||||
actor_name=event.actor.actor_name if event.actor else None,
|
||||
input_summary=f'Artifact created: {artifact_type}',
|
||||
input_json={
|
||||
'artifact_id': registered_id,
|
||||
'artifact_type': artifact_type,
|
||||
'mime_type': mime_type,
|
||||
'name': name,
|
||||
'size_bytes': size_bytes,
|
||||
},
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
)
|
||||
|
||||
# Return artifact ref for Transcript
|
||||
return {
|
||||
'artifact_id': registered_id,
|
||||
'artifact_type': artifact_type,
|
||||
'mime_type': mime_type,
|
||||
'name': name,
|
||||
}
|
||||
|
||||
def _merge_artifact_refs(
|
||||
self,
|
||||
pending_refs: list[dict[str, typing.Any]],
|
||||
result_dict: dict[str, typing.Any],
|
||||
) -> list[dict[str, typing.Any]]:
|
||||
"""Merge pending artifact refs with message's own refs, deduplicating by artifact_id.
|
||||
|
||||
Args:
|
||||
pending_refs: Artifact refs accumulated from artifact.created events
|
||||
result_dict: Result dict that may contain message with artifact_refs
|
||||
|
||||
Returns:
|
||||
Merged and deduplicated list of artifact refs
|
||||
"""
|
||||
# Start with pending refs
|
||||
merged = list(pending_refs)
|
||||
seen_ids = {ref.get('artifact_id') for ref in pending_refs if ref.get('artifact_id')}
|
||||
|
||||
# Extract refs from message data if present
|
||||
data = result_dict.get('data', {})
|
||||
message = data.get('message', {})
|
||||
message_refs = message.get('artifact_refs', [])
|
||||
|
||||
if isinstance(message_refs, list):
|
||||
for ref in message_refs:
|
||||
if isinstance(ref, dict):
|
||||
artifact_id = ref.get('artifact_id')
|
||||
if artifact_id and artifact_id not in seen_ids:
|
||||
merged.append(ref)
|
||||
seen_ids.add(artifact_id)
|
||||
|
||||
return merged
|
||||
|
||||
async def _write_assistant_transcript(
|
||||
self,
|
||||
result_dict: dict[str, typing.Any],
|
||||
event: AgentEventEnvelope,
|
||||
run_id: str,
|
||||
runner_id: str,
|
||||
artifact_refs: list[dict[str, typing.Any]] | None = None,
|
||||
) -> None:
|
||||
"""Write assistant message to Transcript.
|
||||
|
||||
@@ -566,6 +867,7 @@ class AgentRunOrchestrator:
|
||||
event: Original event envelope
|
||||
run_id: Run ID
|
||||
runner_id: Runner ID
|
||||
artifact_refs: Optional artifact references to include
|
||||
"""
|
||||
import uuid
|
||||
|
||||
@@ -601,6 +903,7 @@ class AgentRunOrchestrator:
|
||||
role='assistant',
|
||||
content=content,
|
||||
content_json=content_json,
|
||||
artifact_refs=artifact_refs,
|
||||
thread_id=event.thread_id,
|
||||
item_type='message',
|
||||
run_id=run_id,
|
||||
|
||||
@@ -143,6 +143,15 @@ class AgentResultNormalizer:
|
||||
)
|
||||
return None
|
||||
|
||||
elif result_type == 'artifact.created':
|
||||
# Log for telemetry, consumed by orchestrator
|
||||
artifact_id = data.get('artifact_id', 'unknown')
|
||||
artifact_type = data.get('artifact_type', 'unknown')
|
||||
self.ap.logger.debug(
|
||||
f'Runner {descriptor.id} artifact.created logged: artifact_id={artifact_id}, type={artifact_type}'
|
||||
)
|
||||
return None
|
||||
|
||||
else:
|
||||
# Unknown type - warn and ignore.
|
||||
self.ap.logger.warning(
|
||||
|
||||
860
tests/unit_tests/agent/test_orchestrator_artifact.py
Normal file
860
tests/unit_tests/agent/test_orchestrator_artifact.py
Normal file
@@ -0,0 +1,860 @@
|
||||
"""Tests for artifact.created handling in orchestrator."""
|
||||
import pytest
|
||||
import base64
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import uuid
|
||||
|
||||
from langbot.pkg.agent.runner.orchestrator import (
|
||||
AgentRunOrchestrator,
|
||||
MAX_ARTIFACT_INLINE_BYTES,
|
||||
)
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding
|
||||
from langbot.pkg.agent.runner.errors import RunnerProtocolError
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
|
||||
from langbot.pkg.core import app
|
||||
|
||||
|
||||
class TestArtifactCreatedValidation:
|
||||
"""Test artifact.created validation and protocol errors."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
ap = MagicMock(spec=app.Application)
|
||||
ap.logger = MagicMock()
|
||||
ap.plugin_connector = MagicMock()
|
||||
ap.plugin_connector.is_enable_plugin = True
|
||||
ap.persistence_mgr = MagicMock()
|
||||
ap.persistence_mgr.get_db_engine = MagicMock()
|
||||
return ap
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry(self):
|
||||
"""Create mock registry."""
|
||||
registry = MagicMock()
|
||||
registry.get = AsyncMock()
|
||||
return registry
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event(self):
|
||||
"""Create mock event envelope."""
|
||||
event = MagicMock(spec=AgentEventEnvelope)
|
||||
event.event_id = str(uuid.uuid4())
|
||||
event.event_type = 'message.received'
|
||||
event.source = 'test'
|
||||
event.bot_id = str(uuid.uuid4())
|
||||
event.workspace_id = str(uuid.uuid4())
|
||||
event.conversation_id = str(uuid.uuid4())
|
||||
event.thread_id = None
|
||||
event.event_time = 1700000000
|
||||
event.actor = MagicMock(spec=ActorContext)
|
||||
event.actor.actor_type = 'user'
|
||||
event.actor.actor_id = 'user-123'
|
||||
event.actor.actor_name = 'Test User'
|
||||
event.subject = None
|
||||
event.input = MagicMock(spec=AgentInput)
|
||||
event.input.text = 'Hello'
|
||||
event.input.contents = []
|
||||
event.input.attachments = []
|
||||
return event
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_id_mismatch_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test that run_id mismatch raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
wrong_run_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': wrong_run_id,
|
||||
'data': {
|
||||
'artifact_type': 'image',
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id='test-runner',
|
||||
)
|
||||
|
||||
assert 'run_id mismatch' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_artifact_type_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test that missing artifact_type raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_id': str(uuid.uuid4()),
|
||||
# missing artifact_type
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id='test-runner',
|
||||
)
|
||||
|
||||
assert 'missing required field' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_base64_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test that invalid base64 raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_type': 'image',
|
||||
'content_base64': '!!!invalid-base64!!!',
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id='test-runner',
|
||||
)
|
||||
|
||||
assert 'invalid base64' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_content_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test that content exceeding limit raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
# Create content larger than limit
|
||||
oversized_content = b'x' * (MAX_ARTIFACT_INLINE_BYTES + 1)
|
||||
content_base64 = base64.b64encode(oversized_content).decode('utf-8')
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_type': 'image',
|
||||
'content_base64': content_base64,
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id='test-runner',
|
||||
)
|
||||
|
||||
assert 'exceeds limit' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_artifact_store_failure_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test that ArtifactStore failure raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_type': 'image',
|
||||
},
|
||||
}
|
||||
|
||||
with patch('langbot.pkg.agent.runner.artifact_store.ArtifactStore') as MockArtifactStore:
|
||||
mock_artifact_store = MagicMock()
|
||||
mock_artifact_store.register_artifact = AsyncMock(
|
||||
side_effect=Exception('DB connection failed')
|
||||
)
|
||||
MockArtifactStore.return_value = mock_artifact_store
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id='test-runner',
|
||||
)
|
||||
|
||||
assert 'failed to register artifact' in str(exc_info.value)
|
||||
|
||||
|
||||
class TestArtifactCreatedSuccess:
|
||||
"""Test successful artifact.created handling."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
ap = MagicMock(spec=app.Application)
|
||||
ap.logger = MagicMock()
|
||||
ap.plugin_connector = MagicMock()
|
||||
ap.plugin_connector.is_enable_plugin = True
|
||||
ap.persistence_mgr = MagicMock()
|
||||
ap.persistence_mgr.get_db_engine = MagicMock()
|
||||
return ap
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry(self):
|
||||
"""Create mock registry."""
|
||||
registry = MagicMock()
|
||||
registry.get = AsyncMock()
|
||||
return registry
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event(self):
|
||||
"""Create mock event envelope."""
|
||||
event = MagicMock(spec=AgentEventEnvelope)
|
||||
event.event_id = str(uuid.uuid4())
|
||||
event.event_type = 'message.received'
|
||||
event.source = 'test'
|
||||
event.bot_id = str(uuid.uuid4())
|
||||
event.workspace_id = str(uuid.uuid4())
|
||||
event.conversation_id = str(uuid.uuid4())
|
||||
event.thread_id = None
|
||||
event.event_time = 1700000000
|
||||
event.actor = MagicMock(spec=ActorContext)
|
||||
event.actor.actor_type = 'user'
|
||||
event.actor.actor_id = 'user-123'
|
||||
event.actor.actor_name = 'Test User'
|
||||
event.subject = None
|
||||
return event
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_artifact_created_registers_artifact(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test that artifact.created registers artifact via ArtifactStore."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
runner_id = 'test-runner'
|
||||
|
||||
# Create artifact.created result
|
||||
content = b'test artifact content'
|
||||
content_base64 = base64.b64encode(content).decode('utf-8')
|
||||
artifact_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'image',
|
||||
'mime_type': 'image/png',
|
||||
'name': 'test.png',
|
||||
'size_bytes': len(content),
|
||||
'content_base64': content_base64,
|
||||
},
|
||||
}
|
||||
|
||||
with patch('langbot.pkg.agent.runner.artifact_store.ArtifactStore') as MockArtifactStore:
|
||||
with patch('langbot.pkg.agent.runner.event_log_store.EventLogStore') as MockEventLogStore:
|
||||
mock_artifact_store = MagicMock()
|
||||
mock_artifact_store.register_artifact = AsyncMock(return_value=artifact_id)
|
||||
MockArtifactStore.return_value = mock_artifact_store
|
||||
|
||||
mock_event_log_store = MagicMock()
|
||||
mock_event_log_store.append_event = AsyncMock()
|
||||
MockEventLogStore.return_value = mock_event_log_store
|
||||
|
||||
# Call _handle_artifact_created
|
||||
result = await orchestrator._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
)
|
||||
|
||||
# Verify artifact was registered
|
||||
mock_artifact_store.register_artifact.assert_called_once()
|
||||
call_kwargs = mock_artifact_store.register_artifact.call_args.kwargs
|
||||
assert call_kwargs['artifact_id'] == artifact_id
|
||||
assert call_kwargs['artifact_type'] == 'image'
|
||||
assert call_kwargs['mime_type'] == 'image/png'
|
||||
assert call_kwargs['name'] == 'test.png'
|
||||
assert call_kwargs['content'] == content
|
||||
assert call_kwargs['conversation_id'] == mock_event.conversation_id
|
||||
assert call_kwargs['run_id'] == run_id
|
||||
assert call_kwargs['runner_id'] == runner_id
|
||||
|
||||
# Verify EventLog was written
|
||||
mock_event_log_store.append_event.assert_called_once()
|
||||
event_kwargs = mock_event_log_store.append_event.call_args.kwargs
|
||||
assert event_kwargs['event_type'] == 'artifact.created'
|
||||
assert event_kwargs['run_id'] == run_id
|
||||
|
||||
# Verify artifact ref returned
|
||||
assert result is not None
|
||||
assert result['artifact_id'] == artifact_id
|
||||
assert result['artifact_type'] == 'image'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_artifact_created_metadata_only(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test artifact.created without content (metadata-only)."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
artifact_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'file',
|
||||
'mime_type': 'application/pdf',
|
||||
'name': 'document.pdf',
|
||||
'size_bytes': 1024,
|
||||
'sha256': 'abc123',
|
||||
'metadata': {'source': 'external'},
|
||||
},
|
||||
}
|
||||
|
||||
with patch('langbot.pkg.agent.runner.artifact_store.ArtifactStore') as MockArtifactStore:
|
||||
with patch('langbot.pkg.agent.runner.event_log_store.EventLogStore') as MockEventLogStore:
|
||||
mock_artifact_store = MagicMock()
|
||||
mock_artifact_store.register_artifact = AsyncMock(return_value=artifact_id)
|
||||
MockArtifactStore.return_value = mock_artifact_store
|
||||
|
||||
mock_event_log_store = MagicMock()
|
||||
mock_event_log_store.append_event = AsyncMock()
|
||||
MockEventLogStore.return_value = mock_event_log_store
|
||||
|
||||
result = await orchestrator._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id='test-runner',
|
||||
)
|
||||
|
||||
# Verify artifact was registered without content
|
||||
call_kwargs = mock_artifact_store.register_artifact.call_args.kwargs
|
||||
assert call_kwargs['content'] is None
|
||||
assert call_kwargs['sha256'] == 'abc123'
|
||||
assert call_kwargs['metadata'] == {'source': 'external'}
|
||||
|
||||
assert result is not None
|
||||
assert result['artifact_id'] == artifact_id
|
||||
|
||||
|
||||
class TestArtifactRefsLifecycle:
|
||||
"""Test artifact refs lifecycle in event-first flow."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
ap = MagicMock(spec=app.Application)
|
||||
ap.logger = MagicMock()
|
||||
ap.plugin_connector = MagicMock()
|
||||
ap.plugin_connector.is_enable_plugin = True
|
||||
ap.persistence_mgr = MagicMock()
|
||||
ap.persistence_mgr.get_db_engine = MagicMock()
|
||||
return ap
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry(self):
|
||||
"""Create mock registry."""
|
||||
registry = MagicMock()
|
||||
registry.get = AsyncMock()
|
||||
return registry
|
||||
|
||||
def test_merge_artifact_refs_deduplicates(
|
||||
self, mock_app, mock_registry
|
||||
):
|
||||
"""Test that _merge_artifact_refs deduplicates by artifact_id."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
|
||||
pending_refs = [
|
||||
{'artifact_id': 'artifact-1', 'artifact_type': 'image'},
|
||||
{'artifact_id': 'artifact-2', 'artifact_type': 'file'},
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
'type': 'message.completed',
|
||||
'data': {
|
||||
'message': {
|
||||
'content': 'Hello',
|
||||
'artifact_refs': [
|
||||
{'artifact_id': 'artifact-2', 'artifact_type': 'file'}, # duplicate
|
||||
{'artifact_id': 'artifact-3', 'artifact_type': 'voice'},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
merged = orchestrator._merge_artifact_refs(pending_refs, result_dict)
|
||||
|
||||
# Should have 3 unique artifacts
|
||||
assert len(merged) == 3
|
||||
artifact_ids = {ref['artifact_id'] for ref in merged}
|
||||
assert artifact_ids == {'artifact-1', 'artifact-2', 'artifact-3'}
|
||||
|
||||
def test_merge_artifact_refs_empty_pending(
|
||||
self, mock_app, mock_registry
|
||||
):
|
||||
"""Test merge with empty pending refs."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
|
||||
pending_refs = []
|
||||
|
||||
result_dict = {
|
||||
'type': 'message.completed',
|
||||
'data': {
|
||||
'message': {
|
||||
'content': 'Hello',
|
||||
'artifact_refs': [
|
||||
{'artifact_id': 'artifact-1', 'artifact_type': 'image'},
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
merged = orchestrator._merge_artifact_refs(pending_refs, result_dict)
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0]['artifact_id'] == 'artifact-1'
|
||||
|
||||
def test_merge_artifact_refs_empty_message_refs(
|
||||
self, mock_app, mock_registry
|
||||
):
|
||||
"""Test merge with no message artifact_refs."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
|
||||
pending_refs = [
|
||||
{'artifact_id': 'artifact-1', 'artifact_type': 'image'},
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
'type': 'message.completed',
|
||||
'data': {
|
||||
'message': {
|
||||
'content': 'Hello',
|
||||
# no artifact_refs
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
merged = orchestrator._merge_artifact_refs(pending_refs, result_dict)
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0]['artifact_id'] == 'artifact-1'
|
||||
|
||||
|
||||
class TestArtifactCreatedQueryFlow:
|
||||
"""Test artifact.created handling in legacy Query-based flow."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
ap = MagicMock(spec=app.Application)
|
||||
ap.logger = MagicMock()
|
||||
ap.plugin_connector = MagicMock()
|
||||
ap.plugin_connector.is_enable_plugin = True
|
||||
ap.persistence_mgr = MagicMock()
|
||||
ap.persistence_mgr.get_db_engine = MagicMock()
|
||||
return ap
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry(self):
|
||||
"""Create mock registry."""
|
||||
registry = MagicMock()
|
||||
registry.get = AsyncMock()
|
||||
return registry
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query(self):
|
||||
"""Create mock Query."""
|
||||
from langbot_plugin.api.entities.builtin.pipeline import query as pipeline_query
|
||||
query = MagicMock(spec=pipeline_query.Query)
|
||||
query.query_id = str(uuid.uuid4())
|
||||
query.pipeline_config = {'runner': {'id': 'test-runner'}}
|
||||
query.variables = {}
|
||||
return query
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_flow_run_id_mismatch_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_query
|
||||
):
|
||||
"""Test run_id mismatch in Query flow raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
wrong_run_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': wrong_run_id,
|
||||
'data': {'artifact_type': 'image'},
|
||||
}
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.id = 'test-runner'
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created_query(
|
||||
result_dict=result_dict,
|
||||
query=mock_query,
|
||||
descriptor=mock_descriptor,
|
||||
run_id=run_id,
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
assert 'run_id mismatch' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_flow_invalid_base64_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_query
|
||||
):
|
||||
"""Test invalid base64 in Query flow raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_type': 'image',
|
||||
'content_base64': '!!!invalid!!!',
|
||||
},
|
||||
}
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.id = 'test-runner'
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created_query(
|
||||
result_dict=result_dict,
|
||||
query=mock_query,
|
||||
descriptor=mock_descriptor,
|
||||
run_id=run_id,
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
assert 'invalid base64' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_flow_missing_artifact_type_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_query
|
||||
):
|
||||
"""Test missing artifact_type in Query flow raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {}, # missing artifact_type
|
||||
}
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.id = 'test-runner'
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created_query(
|
||||
result_dict=result_dict,
|
||||
query=mock_query,
|
||||
descriptor=mock_descriptor,
|
||||
run_id=run_id,
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
assert 'missing required field' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_flow_oversized_content_raises_protocol_error(
|
||||
self, mock_app, mock_registry, mock_query
|
||||
):
|
||||
"""Test oversized content in Query flow raises RunnerProtocolError."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
oversized_content = b'x' * (MAX_ARTIFACT_INLINE_BYTES + 1)
|
||||
content_base64 = base64.b64encode(oversized_content).decode('utf-8')
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_type': 'image',
|
||||
'content_base64': content_base64,
|
||||
},
|
||||
}
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.id = 'test-runner'
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await orchestrator._handle_artifact_created_query(
|
||||
result_dict=result_dict,
|
||||
query=mock_query,
|
||||
descriptor=mock_descriptor,
|
||||
run_id=run_id,
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
assert 'exceeds limit' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_flow_register_success(
|
||||
self, mock_app, mock_registry, mock_query
|
||||
):
|
||||
"""Test successful artifact registration in Query flow."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
conversation_id = str(uuid.uuid4())
|
||||
artifact_id = str(uuid.uuid4())
|
||||
|
||||
content = b'test content'
|
||||
content_base64 = base64.b64encode(content).decode('utf-8')
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'voice',
|
||||
'mime_type': 'audio/mp3',
|
||||
'content_base64': content_base64,
|
||||
},
|
||||
}
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.id = 'test-runner'
|
||||
|
||||
with patch('langbot.pkg.agent.runner.artifact_store.ArtifactStore') as MockArtifactStore:
|
||||
mock_artifact_store = MagicMock()
|
||||
mock_artifact_store.register_artifact = AsyncMock(return_value=artifact_id)
|
||||
MockArtifactStore.return_value = mock_artifact_store
|
||||
|
||||
await orchestrator._handle_artifact_created_query(
|
||||
result_dict=result_dict,
|
||||
query=mock_query,
|
||||
descriptor=mock_descriptor,
|
||||
run_id=run_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
# Verify artifact was registered
|
||||
mock_artifact_store.register_artifact.assert_called_once()
|
||||
call_kwargs = mock_artifact_store.register_artifact.call_args.kwargs
|
||||
assert call_kwargs['artifact_id'] == artifact_id
|
||||
assert call_kwargs['artifact_type'] == 'voice'
|
||||
assert call_kwargs['content'] == content
|
||||
assert call_kwargs['conversation_id'] == conversation_id
|
||||
|
||||
|
||||
class TestResultNormalizerArtifactCreated:
|
||||
"""Test ResultNormalizer handling of artifact.created."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
ap = MagicMock(spec=app.Application)
|
||||
ap.logger = MagicMock()
|
||||
return ap
|
||||
|
||||
@pytest.fixture
|
||||
def mock_descriptor(self):
|
||||
"""Create mock descriptor."""
|
||||
descriptor = MagicMock()
|
||||
descriptor.id = 'test-runner'
|
||||
return descriptor
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_artifact_created_returns_none(
|
||||
self, mock_app, mock_descriptor
|
||||
):
|
||||
"""Test that artifact.created is consumed (returns None)."""
|
||||
from langbot.pkg.agent.runner.result_normalizer import AgentResultNormalizer
|
||||
|
||||
normalizer = AgentResultNormalizer(mock_app)
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': 'test-run-id',
|
||||
'data': {
|
||||
'artifact_id': 'artifact-123',
|
||||
'artifact_type': 'image',
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, mock_descriptor)
|
||||
|
||||
# Should return None (consumed)
|
||||
assert result is None
|
||||
|
||||
# Debug log should be written
|
||||
mock_app.logger.debug.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_unknown_type_warning(
|
||||
self, mock_app, mock_descriptor
|
||||
):
|
||||
"""Test that unknown result types still produce warnings."""
|
||||
from langbot.pkg.agent.runner.result_normalizer import AgentResultNormalizer
|
||||
|
||||
normalizer = AgentResultNormalizer(mock_app)
|
||||
|
||||
result_dict = {
|
||||
'type': 'unknown.type',
|
||||
'data': {},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, mock_descriptor)
|
||||
|
||||
# Should return None
|
||||
assert result is None
|
||||
|
||||
# Warning should be logged
|
||||
mock_app.logger.warning.assert_called()
|
||||
|
||||
|
||||
class TestEventLogTranscriptIntegration:
|
||||
"""Test EventLog and Transcript integration with artifact.created."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock application."""
|
||||
ap = MagicMock(spec=app.Application)
|
||||
ap.logger = MagicMock()
|
||||
ap.plugin_connector = MagicMock()
|
||||
ap.plugin_connector.is_enable_plugin = True
|
||||
ap.persistence_mgr = MagicMock()
|
||||
ap.persistence_mgr.get_db_engine = MagicMock()
|
||||
return ap
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry(self):
|
||||
"""Create mock registry."""
|
||||
registry = MagicMock()
|
||||
registry.get = AsyncMock()
|
||||
return registry
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event(self):
|
||||
"""Create mock event envelope."""
|
||||
event = MagicMock(spec=AgentEventEnvelope)
|
||||
event.event_id = str(uuid.uuid4())
|
||||
event.event_type = 'message.received'
|
||||
event.source = 'test'
|
||||
event.bot_id = str(uuid.uuid4())
|
||||
event.workspace_id = str(uuid.uuid4())
|
||||
event.conversation_id = str(uuid.uuid4())
|
||||
event.thread_id = None
|
||||
event.event_time = 1700000000
|
||||
event.actor = MagicMock(spec=ActorContext)
|
||||
event.actor.actor_type = 'user'
|
||||
event.actor.actor_id = 'user-123'
|
||||
event.actor.actor_name = 'Test User'
|
||||
event.subject = None
|
||||
return event
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_log_written_with_correct_event_type(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test that EventLog is written with event_type='artifact.created'."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
artifact_id = str(uuid.uuid4())
|
||||
|
||||
result_dict = {
|
||||
'type': 'artifact.created',
|
||||
'run_id': run_id,
|
||||
'data': {
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'image',
|
||||
},
|
||||
}
|
||||
|
||||
with patch('langbot.pkg.agent.runner.artifact_store.ArtifactStore') as MockArtifactStore:
|
||||
with patch('langbot.pkg.agent.runner.event_log_store.EventLogStore') as MockEventLogStore:
|
||||
mock_artifact_store = MagicMock()
|
||||
mock_artifact_store.register_artifact = AsyncMock(return_value=artifact_id)
|
||||
MockArtifactStore.return_value = mock_artifact_store
|
||||
|
||||
mock_event_log_store = MagicMock()
|
||||
mock_event_log_store.append_event = AsyncMock()
|
||||
MockEventLogStore.return_value = mock_event_log_store
|
||||
|
||||
await orchestrator._handle_artifact_created(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id='test-runner',
|
||||
)
|
||||
|
||||
# Verify EventLog.append_event was called with correct event_type
|
||||
mock_event_log_store.append_event.assert_called_once()
|
||||
call_kwargs = mock_event_log_store.append_event.call_args.kwargs
|
||||
assert call_kwargs['event_type'] == 'artifact.created'
|
||||
assert call_kwargs['source'] == 'runner'
|
||||
assert call_kwargs['conversation_id'] == mock_event.conversation_id
|
||||
assert call_kwargs['run_id'] == run_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assistant_transcript_receives_artifact_refs(
|
||||
self, mock_app, mock_registry, mock_event
|
||||
):
|
||||
"""Test that assistant transcript receives artifact refs from artifact.created."""
|
||||
orchestrator = AgentRunOrchestrator(mock_app, mock_registry)
|
||||
run_id = str(uuid.uuid4())
|
||||
artifact_id = str(uuid.uuid4())
|
||||
|
||||
# Create pending artifact refs
|
||||
pending_refs = [
|
||||
{'artifact_id': artifact_id, 'artifact_type': 'image', 'mime_type': 'image/png'},
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
'type': 'message.completed',
|
||||
'data': {
|
||||
'message': {
|
||||
'content': 'Here is your image',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
with patch('langbot.pkg.agent.runner.transcript_store.TranscriptStore') as MockTranscriptStore:
|
||||
mock_transcript_store = MagicMock()
|
||||
mock_transcript_store.append_transcript = AsyncMock()
|
||||
MockTranscriptStore.return_value = mock_transcript_store
|
||||
|
||||
await orchestrator._write_assistant_transcript(
|
||||
result_dict=result_dict,
|
||||
event=mock_event,
|
||||
run_id=run_id,
|
||||
runner_id='test-runner',
|
||||
artifact_refs=pending_refs,
|
||||
)
|
||||
|
||||
# Verify transcript was written with artifact_refs
|
||||
mock_transcript_store.append_transcript.assert_called_once()
|
||||
call_kwargs = mock_transcript_store.append_transcript.call_args.kwargs
|
||||
assert call_kwargs['artifact_refs'] == pending_refs
|
||||
Reference in New Issue
Block a user