diff --git a/docs/agent-runner-pluginization/HOST_SDK_INFRASTRUCTURE.md b/docs/agent-runner-pluginization/HOST_SDK_INFRASTRUCTURE.md index dadeab5b..d9ca602a 100644 --- a/docs/agent-runner-pluginization/HOST_SDK_INFRASTRUCTURE.md +++ b/docs/agent-runner-pluginization/HOST_SDK_INFRASTRUCTURE.md @@ -299,8 +299,10 @@ permissions: tools: ["detail", "call"] knowledge_bases: ["list", "retrieve"] history: ["page", "search"] + events: ["get", "page"] artifacts: ["metadata", "read"] storage: ["plugin", "workspace", "binding"] + files: ["config", "knowledge"] platform_api: [] ``` diff --git a/docs/agent-runner-pluginization/PROTOCOL_V1.md b/docs/agent-runner-pluginization/PROTOCOL_V1.md index 6608a088..8c60d3dd 100644 --- a/docs/agent-runner-pluginization/PROTOCOL_V1.md +++ b/docs/agent-runner-pluginization/PROTOCOL_V1.md @@ -121,6 +121,7 @@ class AgentRunnerPermissions(BaseModel): events: list[Literal["get", "page"]] = [] artifacts: list[Literal["metadata", "read"]] = [] storage: list[Literal["plugin", "workspace", "binding"]] = [] + files: list[Literal["config", "knowledge"]] = [] platform_api: list[str] = [] ``` @@ -370,7 +371,6 @@ class AgentRunState(BaseModel): actor: dict[str, Any] = {} subject: dict[str, Any] = {} runner: dict[str, Any] = {} - binding: dict[str, Any] = {} ``` State 是可选 host-owned snapshot。Runner 也可以完全自管状态。 @@ -382,13 +382,12 @@ class AgentResources(BaseModel): models: list[ModelResource] = [] tools: list[ToolResource] = [] knowledge_bases: list[KnowledgeBaseResource] = [] - artifacts: list[ArtifactResource] = [] + files: list[FileResource] = [] storage: StorageResource = StorageResource() - history: HistoryResource = HistoryResource() platform_capabilities: dict[str, Any] = {} ``` -资源列表是本次 run 的授权结果。Runner 只能通过 `AgentRunAPIProxy` 访问这些资源。 +资源列表是本次 run 的授权结果。History / Event / Artifact 访问通过 permissions、`ctx.context.available_apis` 和 Host 侧 run session 校验控制,不作为可枚举 resource list 暴露。Runner 只能通过 `AgentRunAPIProxy` 访问这些能力。 ## 6. Result Stream diff --git a/src/langbot/pkg/agent/runner/host_models.py b/src/langbot/pkg/agent/runner/host_models.py index 12ebf796..ffa96604 100644 --- a/src/langbot/pkg/agent/runner/host_models.py +++ b/src/langbot/pkg/agent/runner/host_models.py @@ -89,13 +89,13 @@ class ResourcePolicy(pydantic.BaseModel): """ allowed_model_uuids: list[str] | None = None - """Allowed model UUIDs. None means all authorized.""" + """Additional model UUID grants. None means no additional model grants.""" allowed_tool_names: list[str] | None = None - """Allowed tool names. None means all authorized.""" + """Additional tool name grants. None means no additional tool grants.""" allowed_kb_uuids: list[str] | None = None - """Allowed knowledge base UUIDs. None means all authorized.""" + """Additional knowledge base UUID grants. None means no additional KB grants.""" allow_plugin_storage: bool = True """Whether plugin storage is allowed.""" diff --git a/src/langbot/pkg/agent/runner/orchestrator.py b/src/langbot/pkg/agent/runner/orchestrator.py index f8bd3f03..90dafbf1 100644 --- a/src/langbot/pkg/agent/runner/orchestrator.py +++ b/src/langbot/pkg/agent/runner/orchestrator.py @@ -176,6 +176,13 @@ class AgentRunOrchestrator: runner_id=descriptor.id, ) + # Register incoming attachments so input/transcript artifact_refs are resolvable. + await self._register_input_artifacts( + event=event, + run_id=run_id, + runner_id=descriptor.id, + ) + # Write user message to Transcript if message.received if event.event_type == 'message.received' and event.conversation_id: await self._write_user_transcript( @@ -526,6 +533,97 @@ class AgentRunOrchestrator: event_time=datetime.datetime.fromtimestamp(event.event_time) if event.event_time else None, ) + async def _register_input_artifacts( + self, + event: AgentEventEnvelope, + run_id: str, + runner_id: str, + ) -> None: + """Register current-event attachments referenced by AgentInput.""" + if not event.input or not event.input.attachments: + return + + from .artifact_store import ArtifactStore + store = ArtifactStore(self.ap.persistence_mgr.get_db_engine()) + + for attachment in event.input.attachments: + data = attachment.model_dump(mode='json') if hasattr(attachment, 'model_dump') else attachment + if not isinstance(data, dict): + continue + + artifact_id = data.get('artifact_id') + artifact_type = data.get('artifact_type') or 'file' + if not artifact_id: + continue + + content, parsed_mime_type = self._decode_attachment_content(data.get('content')) + url = data.get('url') + platform_ref_id = data.get('id') + storage_key = None + storage_type = 'metadata_only' + if content is None: + if url: + storage_key = url + storage_type = 'url' + elif platform_ref_id: + storage_key = platform_ref_id + storage_type = 'platform_ref' + + metadata = { + 'input_attachment': True, + 'input_source': data.get('source') or 'platform', + } + if url: + metadata['url'] = url + if platform_ref_id: + metadata['platform_ref_id'] = platform_ref_id + + try: + await store.register_artifact( + artifact_id=artifact_id, + artifact_type=artifact_type, + source='platform', + storage_key=storage_key, + storage_type=storage_type, + mime_type=data.get('mime_type') or parsed_mime_type, + name=data.get('name'), + size_bytes=data.get('size') or (len(content) if content is not None else None), + conversation_id=event.conversation_id, + run_id=run_id, + runner_id=runner_id, + bot_id=event.bot_id, + workspace_id=event.workspace_id, + metadata=metadata, + content=content, + ) + except Exception as e: + self.ap.logger.warning( + f'Failed to register input artifact {artifact_id}: {e}' + ) + + def _decode_attachment_content( + self, + content: typing.Any, + ) -> tuple[bytes | None, str | None]: + """Decode base64 attachment content, including data URLs.""" + if not isinstance(content, str) or not content: + return None, None + + import base64 + import binascii + + mime_type = None + payload = content + if content.startswith('data:') and ',' in content: + header, payload = content.split(',', 1) + if ';base64' in header: + mime_type = header[5:].split(';', 1)[0] or None + + try: + return base64.b64decode(payload, validate=False), mime_type + except (binascii.Error, ValueError): + return None, mime_type + async def _write_user_transcript( self, event: AgentEventEnvelope, diff --git a/src/langbot/pkg/agent/runner/persistent_state_store.py b/src/langbot/pkg/agent/runner/persistent_state_store.py index 8208dd52..df4d6f8a 100644 --- a/src/langbot/pkg/agent/runner/persistent_state_store.py +++ b/src/langbot/pkg/agent/runner/persistent_state_store.py @@ -250,6 +250,8 @@ class PersistentStateStore: Used by State API handlers. """ + state_key = normalize_state_key(state_key) + async with self._db_engine.connect() as conn: result = await conn.execute( select(AgentRunnerState.value_json) @@ -282,6 +284,8 @@ class PersistentStateStore: Used by State API handlers. Context contains optional fields like bot_id, conversation_id, etc. """ + state_key = normalize_state_key(state_key) + # Validate and serialize value value_json, error = self._validate_json_value(value, logger) if error: @@ -344,6 +348,8 @@ class PersistentStateStore: Returns True if deleted, False if not found. """ + state_key = normalize_state_key(state_key) + async with self._db_engine.begin() as conn: result = await conn.execute( delete(AgentRunnerState) @@ -376,6 +382,7 @@ class PersistentStateStore: ) if prefix: + prefix = normalize_state_key(prefix) query = query.where( AgentRunnerState.state_key.like(f'{prefix}%') ) diff --git a/src/langbot/pkg/agent/runner/resource_builder.py b/src/langbot/pkg/agent/runner/resource_builder.py index d930c44a..b4df99e5 100644 --- a/src/langbot/pkg/agent/runner/resource_builder.py +++ b/src/langbot/pkg/agent/runner/resource_builder.py @@ -103,12 +103,13 @@ class AgentResourceBuilder: models: list[ModelResource] = [] seen_model_ids: set[str] = set() - # Check manifest permission model_perms = manifest_perms.get('models', []) - if 'invoke' not in model_perms and 'stream' not in model_perms: + allow_llm = 'invoke' in model_perms or 'stream' in model_perms + allow_rerank = 'rerank' in model_perms + if not allow_llm and not allow_rerank: return models - # Get model UUIDs from resource policy + # Get additional model UUID grants from resource policy. allowed_uuids = resource_policy.allowed_model_uuids # Add model resources from binding config schema @@ -117,10 +118,12 @@ class AgentResourceBuilder: seen_model_ids=seen_model_ids, descriptor=descriptor, runner_config=runner_config, + include_llm=allow_llm, + include_rerank=allow_rerank, ) # Add explicitly allowed models - if allowed_uuids: + if allowed_uuids and allow_llm: for model_uuid in allowed_uuids: await self._append_llm_model_resource(models, seen_model_ids, model_uuid) @@ -168,13 +171,13 @@ class AgentResourceBuilder: if 'list' not in kb_perms and 'retrieve' not in kb_perms: return kb_resources - # Get KB UUIDs from schema-defined config fields + # Get KB UUID grants from schema-defined config fields. kb_uuids = config_schema.extract_knowledge_base_uuids(descriptor, runner_config) - # Also check resource policy + # Also include resource policy grants. allowed_uuids = resource_policy.allowed_kb_uuids if allowed_uuids: - kb_uuids = allowed_uuids + kb_uuids = list(dict.fromkeys([*kb_uuids, *allowed_uuids])) for kb_uuid in kb_uuids: try: @@ -210,12 +213,14 @@ class AgentResourceBuilder: seen_model_ids: set[str], descriptor: AgentRunnerDescriptor, runner_config: dict[str, typing.Any], + include_llm: bool, + include_rerank: bool, ) -> None: """Authorize model-like values selected through DynamicForm fields.""" for model_type, model_uuid in config_schema.iter_config_model_refs(descriptor, runner_config): - if model_type == 'llm': + if model_type == 'llm' and include_llm: await self._append_llm_model_resource(models, seen_model_ids, model_uuid) - elif model_type == 'rerank': + elif model_type == 'rerank' and include_rerank: await self._append_rerank_model_resource(models, seen_model_ids, model_uuid) async def _append_llm_model_resource( diff --git a/src/langbot/pkg/agent/runner/transcript_store.py b/src/langbot/pkg/agent/runner/transcript_store.py index 05064525..ef18115f 100644 --- a/src/langbot/pkg/agent/runner/transcript_store.py +++ b/src/langbot/pkg/agent/runner/transcript_store.py @@ -73,9 +73,6 @@ class TranscriptStore: if content and len(content) > self.MAX_CONTENT_LENGTH: content = content[:self.MAX_CONTENT_LENGTH - 3] + "..." - # Get next sequence number for this conversation - seq = await self._get_next_seq(conversation_id) - async with self._session_factory() as session: item = Transcript( transcript_id=transcript_id, @@ -87,13 +84,15 @@ class TranscriptStore: content=content, content_json=json.dumps(content_json) if content_json else None, artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None, - seq=seq, + seq=0, run_id=run_id, runner_id=runner_id, created_at=datetime.datetime.utcnow(), metadata_json=json.dumps(metadata) if metadata else None, ) session.add(item) + await session.flush() + item.seq = item.id or await self._get_next_seq(conversation_id) await session.commit() return transcript_id @@ -253,7 +252,7 @@ class TranscriptStore: return count > 0 async def _get_next_seq(self, conversation_id: str) -> int: - """Get the next sequence number for a conversation.""" + """Fallback next sequence number for stores that cannot expose autoincrement IDs.""" async with self._session_factory() as session: result = await session.execute( sqlalchemy.select(sqlalchemy.func.max(Transcript.seq)) diff --git a/src/langbot/pkg/entity/persistence/transcript.py b/src/langbot/pkg/entity/persistence/transcript.py index 5d72340b..da0a894c 100644 --- a/src/langbot/pkg/entity/persistence/transcript.py +++ b/src/langbot/pkg/entity/persistence/transcript.py @@ -50,7 +50,7 @@ class Transcript(Base): # Sequence for cursor-based pagination seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, index=True) - """Sequence number within conversation (auto-increment per conversation).""" + """Monotonic cursor sequence for pagination.""" # Context run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True) diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index 312d2877..0e5cb44a 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -141,6 +141,70 @@ def _validate_artifact_access( return False, f'Artifact {operation} access denied: artifact not in session conversation and not created by this run' +async def _validate_agent_run_session( + run_id: str, + caller_plugin_identity: str | None, + ap: app.Application, + api_name: str, + permission_group: str | None = None, + permission_operation: str | None = None, +) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]: + """Validate an AgentRunner pull API run session and optional manifest permission.""" + session_registry = get_session_registry() + session = await session_registry.get(run_id) + if not session: + return None, handler.ActionResponse.error( + message=f'Run session {run_id} not found or expired' + ) + + session_plugin_identity = session.get('plugin_identity') + if session_plugin_identity: + if not caller_plugin_identity: + return None, handler.ActionResponse.error( + message=f'caller_plugin_identity is required for run_id {run_id}' + ) + if caller_plugin_identity != session_plugin_identity: + ap.logger.warning( + f'{api_name}: caller_plugin_identity {caller_plugin_identity} ' + f'does not match session plugin_identity {session_plugin_identity}' + ) + return None, handler.ActionResponse.error( + message=f'Plugin identity mismatch for run_id {run_id}' + ) + + if permission_group and permission_operation: + permissions = session.get('permissions', {}) + allowed_operations = permissions.get(permission_group, []) + if permission_operation not in allowed_operations: + return None, handler.ActionResponse.error( + message=f'{api_name} access not authorized' + ) + + return session, None + + +def _resolve_run_conversation( + session: dict[str, Any], + requested_conversation_id: str | None, + api_name: str, +) -> tuple[str | None, handler.ActionResponse | None]: + """Resolve and enforce current-run conversation scope.""" + session_conversation_id = session.get('conversation_id') + + if requested_conversation_id: + if not session_conversation_id: + return None, handler.ActionResponse.error( + message=f'{api_name} is not available without a run conversation' + ) + if requested_conversation_id != session_conversation_id: + return None, handler.ActionResponse.error( + message=f'Conversation {requested_conversation_id} is not accessible by this run' + ) + return requested_conversation_id, None + + return session_conversation_id, None + + def _normalize_uuid_list(values: Any) -> list[str]: """Normalize a user/config supplied UUID list while preserving order.""" if not isinstance(values, list): @@ -1197,7 +1261,7 @@ class RuntimeConnectionHandler(handler.Handler): kb_id = data['kb_id'] query_text = data['query_text'] top_k = data.get('top_k', 5) - filters = data.get('filters', {}) + filters = data.get('filters') or {} run_id = data.get('run_id') # Optional: present for AgentRunner calls caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation @@ -1271,7 +1335,7 @@ class RuntimeConnectionHandler(handler.Handler): kb_id = data['kb_id'] query_text = data['query_text'] top_k = data.get('top_k', 5) - filters = data.get('filters', {}) + filters = data.get('filters') or {} run_id = data.get('run_id') # Optional: present for AgentRunner calls caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation @@ -1342,29 +1406,24 @@ class RuntimeConnectionHandler(handler.Handler): if not run_id: return handler.ActionResponse.error(message='run_id is required') - # Validate run session - session_registry = get_session_registry() - session = await session_registry.get(run_id) - if not session: - return handler.ActionResponse.error( - message=f'Run session {run_id} not found or expired' - ) + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'History page', + permission_group='history', + permission_operation='page', + ) + if error: + return error - # Validate caller plugin identity (strict: required when session has plugin_identity) - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity: - if not caller_plugin_identity: - return handler.ActionResponse.error( - message=f'caller_plugin_identity is required for run_id {run_id}' - ) - if caller_plugin_identity != session_plugin_identity: - return handler.ActionResponse.error( - message=f'Plugin identity mismatch for run_id {run_id}' - ) - - # Get conversation from session if not provided - if not conversation_id: - conversation_id = session.get('conversation_id') + conversation_id, scope_error = _resolve_run_conversation( + session, + conversation_id, + 'History page', + ) + if scope_error: + return scope_error if not conversation_id: return handler.ActionResponse.success(data={ @@ -1411,35 +1470,32 @@ class RuntimeConnectionHandler(handler.Handler): """ run_id = data.get('run_id') query_text = data.get('query', '') - filters = data.get('filters', {}) + filters = data.get('filters') or {} top_k = data.get('top_k', 10) caller_plugin_identity = data.get('caller_plugin_identity') if not run_id: return handler.ActionResponse.error(message='run_id is required') - # Validate run session - session_registry = get_session_registry() - session = await session_registry.get(run_id) - if not session: - return handler.ActionResponse.error( - message=f'Run session {run_id} not found or expired' - ) + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'History search', + permission_group='history', + permission_operation='search', + ) + if error: + return error - # Validate caller plugin identity (strict: required when session has plugin_identity) - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity: - if not caller_plugin_identity: - return handler.ActionResponse.error( - message=f'caller_plugin_identity is required for run_id {run_id}' - ) - if caller_plugin_identity != session_plugin_identity: - return handler.ActionResponse.error( - message=f'Plugin identity mismatch for run_id {run_id}' - ) - - # Get conversation from session or filters - conversation_id = filters.get('conversation_id') or session.get('conversation_id') + requested_conversation_id = filters.get('conversation_id') + conversation_id, scope_error = _resolve_run_conversation( + session, + requested_conversation_id, + 'History search', + ) + if scope_error: + return scope_error if not conversation_id: return handler.ActionResponse.success(data={ @@ -1453,10 +1509,11 @@ class RuntimeConnectionHandler(handler.Handler): store = TranscriptStore(self.ap.persistence_mgr.get_db_engine()) try: + safe_filters = {k: v for k, v in filters.items() if k != 'conversation_id'} items = await store.search_transcript( conversation_id=conversation_id, query_text=query_text, - filters=filters, + filters=safe_filters, top_k=top_k, ) @@ -1485,25 +1542,16 @@ class RuntimeConnectionHandler(handler.Handler): if not event_id: return handler.ActionResponse.error(message='event_id is required') - # Validate run session - session_registry = get_session_registry() - session = await session_registry.get(run_id) - if not session: - return handler.ActionResponse.error( - message=f'Run session {run_id} not found or expired' - ) - - # Validate caller plugin identity (strict: required when session has plugin_identity) - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity: - if not caller_plugin_identity: - return handler.ActionResponse.error( - message=f'caller_plugin_identity is required for run_id {run_id}' - ) - if caller_plugin_identity != session_plugin_identity: - return handler.ActionResponse.error( - message=f'Plugin identity mismatch for run_id {run_id}' - ) + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Event get', + permission_group='events', + permission_operation='get', + ) + if error: + return error # Get event from ..agent.runner.event_log_store import EventLogStore @@ -1516,9 +1564,12 @@ class RuntimeConnectionHandler(handler.Handler): message=f'Event {event_id} not found' ) - # Validate event is in the same conversation as the run + # Validate event is in the same conversation as the run, or was created by the same run. session_conversation_id = session.get('conversation_id') - if session_conversation_id and event.get('conversation_id') != session_conversation_id: + event_run_id = event.get('run_id') + if event_run_id and event_run_id == run_id: + return handler.ActionResponse.success(data=event) + if not session_conversation_id or event.get('conversation_id') != session_conversation_id: return handler.ActionResponse.error( message=f'Event {event_id} is not accessible by this run' ) @@ -1544,29 +1595,24 @@ class RuntimeConnectionHandler(handler.Handler): if not run_id: return handler.ActionResponse.error(message='run_id is required') - # Validate run session - session_registry = get_session_registry() - session = await session_registry.get(run_id) - if not session: - return handler.ActionResponse.error( - message=f'Run session {run_id} not found or expired' - ) + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Event page', + permission_group='events', + permission_operation='page', + ) + if error: + return error - # Validate caller plugin identity (strict: required when session has plugin_identity) - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity: - if not caller_plugin_identity: - return handler.ActionResponse.error( - message=f'caller_plugin_identity is required for run_id {run_id}' - ) - if caller_plugin_identity != session_plugin_identity: - return handler.ActionResponse.error( - message=f'Plugin identity mismatch for run_id {run_id}' - ) - - # Get conversation from session if not provided - if not conversation_id: - conversation_id = session.get('conversation_id') + conversation_id, scope_error = _resolve_run_conversation( + session, + conversation_id, + 'Event page', + ) + if scope_error: + return scope_error if not conversation_id: return handler.ActionResponse.success(data={ @@ -1620,33 +1666,16 @@ class RuntimeConnectionHandler(handler.Handler): if not artifact_id: return handler.ActionResponse.error(message='artifact_id is required') - # Validate run session - session_registry = get_session_registry() - session = await session_registry.get(run_id) - if not session: - return handler.ActionResponse.error( - message=f'Run session {run_id} not found or expired' - ) - - # Validate caller plugin identity (strict: required when session has plugin_identity) - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity: - if not caller_plugin_identity: - return handler.ActionResponse.error( - message=f'caller_plugin_identity is required for run_id {run_id}' - ) - if caller_plugin_identity != session_plugin_identity: - return handler.ActionResponse.error( - message=f'Plugin identity mismatch for run_id {run_id}' - ) - - # Check artifact permission from session.permissions (from descriptor.permissions) - permissions = session.get('permissions', {}) - artifact_permissions = permissions.get('artifacts', []) - if 'metadata' not in artifact_permissions: - return handler.ActionResponse.error( - message='Artifact metadata access not authorized' - ) + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Artifact metadata', + permission_group='artifacts', + permission_operation='metadata', + ) + if error: + return error # Get artifact metadata from ..agent.runner.artifact_store import ArtifactStore @@ -1708,33 +1737,16 @@ class RuntimeConnectionHandler(handler.Handler): if limit <= 0: return handler.ActionResponse.error(message='limit must be > 0') - # Validate run session - session_registry = get_session_registry() - session = await session_registry.get(run_id) - if not session: - return handler.ActionResponse.error( - message=f'Run session {run_id} not found or expired' - ) - - # Validate caller plugin identity (strict: required when session has plugin_identity) - session_plugin_identity = session.get('plugin_identity') - if session_plugin_identity: - if not caller_plugin_identity: - return handler.ActionResponse.error( - message=f'caller_plugin_identity is required for run_id {run_id}' - ) - if caller_plugin_identity != session_plugin_identity: - return handler.ActionResponse.error( - message=f'Plugin identity mismatch for run_id {run_id}' - ) - - # Check artifact permission from session.permissions (from descriptor.permissions) - permissions = session.get('permissions', {}) - artifact_permissions = permissions.get('artifacts', []) - if 'read' not in artifact_permissions: - return handler.ActionResponse.error( - message='Artifact read access not authorized' - ) + session, error = await _validate_agent_run_session( + run_id, + caller_plugin_identity, + self.ap, + 'Artifact read', + permission_group='artifacts', + permission_operation='read', + ) + if error: + return error # Get artifact metadata first to validate access from ..agent.runner.artifact_store import ArtifactStore diff --git a/tests/unit_tests/agent/test_history_event_api_auth.py b/tests/unit_tests/agent/test_history_event_api_auth.py new file mode 100644 index 00000000..00728fe5 --- /dev/null +++ b/tests/unit_tests/agent/test_history_event_api_auth.py @@ -0,0 +1,146 @@ +"""Tests for AgentRunner history/event pull API authorization.""" +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.ext.asyncio import create_async_engine + +from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry +from langbot.pkg.plugin.handler import RuntimeConnectionHandler +from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction + +from .conftest import make_resources + + +class FakeConnection: + pass + + +class FakeApplication: + def __init__(self, db_engine): + self.logger = MagicMock() + self.persistence_mgr = MagicMock() + self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine) + + +@pytest.fixture +def session_registry(monkeypatch): + registry = AgentRunSessionRegistry() + monkeypatch.setattr( + 'langbot.pkg.agent.runner.session_registry._global_registry', + registry, + ) + return registry + + +@pytest.fixture +async def db_engine(): + engine = create_async_engine('sqlite+aiosqlite:///:memory:') + yield engine + await engine.dispose() + + +def _handler(db_engine, session_registry): + async def fake_disconnect(): + return True + + fake_app = FakeApplication(db_engine) + return RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app) + + +async def _register_session( + session_registry, + *, + run_id='run_1', + conversation_id='conv_1', + permissions=None, +): + await session_registry.register( + run_id=run_id, + runner_id='plugin:test/runner/default', + query_id=None, + plugin_identity='test/runner', + resources=make_resources(), + conversation_id=conversation_id, + permissions=permissions or {}, + ) + + +@pytest.mark.asyncio +async def test_history_page_requires_manifest_permission(session_registry, db_engine): + await _register_session(session_registry, permissions={'history': []}) + handler = _handler(db_engine, session_registry) + history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value] + + result = await history_page({ + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'not authorized' in result.message.lower() + + +@pytest.mark.asyncio +async def test_history_page_rejects_cross_conversation(session_registry, db_engine): + await _register_session(session_registry, permissions={'history': ['page']}) + handler = _handler(db_engine, session_registry) + history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value] + + result = await history_page({ + 'run_id': 'run_1', + 'conversation_id': 'conv_other', + 'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'not accessible' in result.message.lower() + + +@pytest.mark.asyncio +async def test_history_search_rejects_filter_conversation_override(session_registry, db_engine): + await _register_session(session_registry, permissions={'history': ['search']}) + handler = _handler(db_engine, session_registry) + history_search = handler.actions[PluginToRuntimeAction.HISTORY_SEARCH.value] + + result = await history_search({ + 'run_id': 'run_1', + 'query': 'hello', + 'filters': {'conversation_id': 'conv_other'}, + 'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'not accessible' in result.message.lower() + + +@pytest.mark.asyncio +async def test_event_page_requires_manifest_permission(session_registry, db_engine): + await _register_session(session_registry, permissions={'events': []}) + handler = _handler(db_engine, session_registry) + event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value] + + result = await event_page({ + 'run_id': 'run_1', + 'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'not authorized' in result.message.lower() + + +@pytest.mark.asyncio +async def test_event_page_rejects_cross_conversation(session_registry, db_engine): + await _register_session(session_registry, permissions={'events': ['page']}) + handler = _handler(db_engine, session_registry) + event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value] + + result = await event_page({ + 'run_id': 'run_1', + 'conversation_id': 'conv_other', + 'caller_plugin_identity': 'test/runner', + }) + + assert result.code != 0 + assert 'not accessible' in result.message.lower() diff --git a/tests/unit_tests/agent/test_resource_builder.py b/tests/unit_tests/agent/test_resource_builder.py index 1576a430..48d94032 100644 --- a/tests/unit_tests/agent/test_resource_builder.py +++ b/tests/unit_tests/agent/test_resource_builder.py @@ -18,6 +18,7 @@ def make_descriptor( *, permissions: dict | None = None, config_schema: list[dict] | None = None, + capabilities: dict | None = None, ) -> AgentRunnerDescriptor: return AgentRunnerDescriptor( id=RUNNER_ID, @@ -26,6 +27,7 @@ def make_descriptor( plugin_author='test', plugin_name='runner', runner_name='default', + capabilities=capabilities or {}, permissions=permissions or {'models': ['invoke', 'stream']}, config_schema=config_schema or [], ) @@ -99,6 +101,7 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=get_model_by_uuid) app.model_mgr.get_rerank_model_by_uuid = AsyncMock(side_effect=get_rerank_model_by_uuid) descriptor = make_descriptor( + permissions={'models': ['invoke', 'stream', 'rerank']}, config_schema=[ {'name': 'model', 'type': 'model-fallback-selector'}, {'name': 'aux-model', 'type': 'llm-model-selector'}, @@ -145,6 +148,33 @@ async def test_build_models_still_honors_manifest_permissions(app): app.model_mgr.get_rerank_model_by_uuid.assert_not_awaited() +@pytest.mark.asyncio +async def test_build_models_authorizes_rerank_only_runner(app): + """A rerank-only runner should receive config-selected rerank models.""" + app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model()) + app.model_mgr.get_rerank_model_by_uuid = AsyncMock( + return_value=make_model(model_type='rerank', provider='rerank-provider') + ) + descriptor = make_descriptor( + permissions={'models': ['rerank']}, + config_schema=[ + {'name': 'model', 'type': 'llm-model-selector'}, + {'name': 'rerank-model', 'type': 'rerank-model-selector'}, + ], + ) + query = make_query({ + 'model': 'llm', + 'rerank-model': 'rerank', + }) + + resources = await build_resources(app, query, descriptor) + + assert resources['models'] == [ + {'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider'}, + ] + app.model_mgr.get_model_by_uuid.assert_not_awaited() + + @pytest.mark.asyncio async def test_build_models_deduplicates_query_and_config_models(app): """A model selected by both preproc and runner config should appear once.""" @@ -197,3 +227,37 @@ async def test_build_tools_authorizes_query_declared_tools(app): 'description': None, }, ] + + +@pytest.mark.asyncio +async def test_build_knowledge_bases_unions_config_and_policy_grants(app): + descriptor = make_descriptor( + capabilities={'knowledge_retrieval': True}, + permissions={ + 'models': [], + 'knowledge_bases': ['retrieve'], + }, + config_schema=[ + {'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'}, + ], + ) + query = make_query( + {'knowledge-bases': ['kb_config']}, + variables={'_knowledge_base_uuids': ['kb_policy']}, + ) + + async def get_kb(kb_uuid): + return SimpleNamespace( + uuid=kb_uuid, + get_name=lambda: f'name-{kb_uuid}', + knowledge_base_entity=SimpleNamespace(kb_type='default'), + ) + + app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(side_effect=get_kb) + + resources = await build_resources(app, query, descriptor) + + assert resources['knowledge_bases'] == [ + {'kb_id': 'kb_config', 'kb_name': 'name-kb_config', 'kb_type': 'default'}, + {'kb_id': 'kb_policy', 'kb_name': 'name-kb_policy', 'kb_type': 'default'}, + ] diff --git a/tests/unit_tests/agent/test_state_store.py b/tests/unit_tests/agent/test_state_store.py index 88e9d5fe..05166110 100644 --- a/tests/unit_tests/agent/test_state_store.py +++ b/tests/unit_tests/agent/test_state_store.py @@ -213,6 +213,30 @@ class TestPersistentStateStore: snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor) assert snapshot['conversation']['test_key'] == {'nested': 'value'} + @pytest.mark.asyncio + async def test_state_api_methods_normalize_public_key_aliases(self, persistent_store): + scope_key = 'conversation:runner:binding:conv_001' + + success, error = await persistent_store.state_set( + scope_key=scope_key, + state_key='conversation_id', + value='conv_001', + runner_id='plugin:test/my-runner/default', + binding_identity='binding_001', + scope='conversation', + ) + + assert success is True + assert error is None + assert await persistent_store.state_get(scope_key, 'external.conversation_id') == 'conv_001' + assert await persistent_store.state_get(scope_key, 'conversation_id') == 'conv_001' + + keys, _ = await persistent_store.state_list(scope_key, prefix='conversation_id') + assert keys == ['external.conversation_id'] + + assert await persistent_store.state_delete(scope_key, 'conversation_id') is True + assert await persistent_store.state_get(scope_key, 'external.conversation_id') is None + @pytest.mark.asyncio async def test_binding_isolation(self, persistent_store): descriptor = make_descriptor()