mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 07:54:19 +00:00
fix: enforce agent run API permissions
This commit is contained in:
committed by
huanghuoguoguo
parent
c296c187f4
commit
0b9778abd9
@@ -299,8 +299,10 @@ permissions:
|
|||||||
tools: ["detail", "call"]
|
tools: ["detail", "call"]
|
||||||
knowledge_bases: ["list", "retrieve"]
|
knowledge_bases: ["list", "retrieve"]
|
||||||
history: ["page", "search"]
|
history: ["page", "search"]
|
||||||
|
events: ["get", "page"]
|
||||||
artifacts: ["metadata", "read"]
|
artifacts: ["metadata", "read"]
|
||||||
storage: ["plugin", "workspace", "binding"]
|
storage: ["plugin", "workspace", "binding"]
|
||||||
|
files: ["config", "knowledge"]
|
||||||
platform_api: []
|
platform_api: []
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ class AgentRunnerPermissions(BaseModel):
|
|||||||
events: list[Literal["get", "page"]] = []
|
events: list[Literal["get", "page"]] = []
|
||||||
artifacts: list[Literal["metadata", "read"]] = []
|
artifacts: list[Literal["metadata", "read"]] = []
|
||||||
storage: list[Literal["plugin", "workspace", "binding"]] = []
|
storage: list[Literal["plugin", "workspace", "binding"]] = []
|
||||||
|
files: list[Literal["config", "knowledge"]] = []
|
||||||
platform_api: list[str] = []
|
platform_api: list[str] = []
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -370,7 +371,6 @@ class AgentRunState(BaseModel):
|
|||||||
actor: dict[str, Any] = {}
|
actor: dict[str, Any] = {}
|
||||||
subject: dict[str, Any] = {}
|
subject: dict[str, Any] = {}
|
||||||
runner: dict[str, Any] = {}
|
runner: dict[str, Any] = {}
|
||||||
binding: dict[str, Any] = {}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
State 是可选 host-owned snapshot。Runner 也可以完全自管状态。
|
State 是可选 host-owned snapshot。Runner 也可以完全自管状态。
|
||||||
@@ -382,13 +382,12 @@ class AgentResources(BaseModel):
|
|||||||
models: list[ModelResource] = []
|
models: list[ModelResource] = []
|
||||||
tools: list[ToolResource] = []
|
tools: list[ToolResource] = []
|
||||||
knowledge_bases: list[KnowledgeBaseResource] = []
|
knowledge_bases: list[KnowledgeBaseResource] = []
|
||||||
artifacts: list[ArtifactResource] = []
|
files: list[FileResource] = []
|
||||||
storage: StorageResource = StorageResource()
|
storage: StorageResource = StorageResource()
|
||||||
history: HistoryResource = HistoryResource()
|
|
||||||
platform_capabilities: dict[str, Any] = {}
|
platform_capabilities: dict[str, Any] = {}
|
||||||
```
|
```
|
||||||
|
|
||||||
资源列表是本次 run 的授权结果。Runner 只能通过 `AgentRunAPIProxy` 访问这些资源。
|
资源列表是本次 run 的授权结果。History / Event / Artifact 访问通过 permissions、`ctx.context.available_apis` 和 Host 侧 run session 校验控制,不作为可枚举 resource list 暴露。Runner 只能通过 `AgentRunAPIProxy` 访问这些能力。
|
||||||
|
|
||||||
## 6. Result Stream
|
## 6. Result Stream
|
||||||
|
|
||||||
|
|||||||
@@ -89,13 +89,13 @@ class ResourcePolicy(pydantic.BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
allowed_model_uuids: list[str] | None = None
|
allowed_model_uuids: list[str] | None = None
|
||||||
"""Allowed model UUIDs. None means all authorized."""
|
"""Additional model UUID grants. None means no additional model grants."""
|
||||||
|
|
||||||
allowed_tool_names: list[str] | None = None
|
allowed_tool_names: list[str] | None = None
|
||||||
"""Allowed tool names. None means all authorized."""
|
"""Additional tool name grants. None means no additional tool grants."""
|
||||||
|
|
||||||
allowed_kb_uuids: list[str] | None = None
|
allowed_kb_uuids: list[str] | None = None
|
||||||
"""Allowed knowledge base UUIDs. None means all authorized."""
|
"""Additional knowledge base UUID grants. None means no additional KB grants."""
|
||||||
|
|
||||||
allow_plugin_storage: bool = True
|
allow_plugin_storage: bool = True
|
||||||
"""Whether plugin storage is allowed."""
|
"""Whether plugin storage is allowed."""
|
||||||
|
|||||||
@@ -176,6 +176,13 @@ class AgentRunOrchestrator:
|
|||||||
runner_id=descriptor.id,
|
runner_id=descriptor.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Register incoming attachments so input/transcript artifact_refs are resolvable.
|
||||||
|
await self._register_input_artifacts(
|
||||||
|
event=event,
|
||||||
|
run_id=run_id,
|
||||||
|
runner_id=descriptor.id,
|
||||||
|
)
|
||||||
|
|
||||||
# Write user message to Transcript if message.received
|
# Write user message to Transcript if message.received
|
||||||
if event.event_type == 'message.received' and event.conversation_id:
|
if event.event_type == 'message.received' and event.conversation_id:
|
||||||
await self._write_user_transcript(
|
await self._write_user_transcript(
|
||||||
@@ -526,6 +533,97 @@ class AgentRunOrchestrator:
|
|||||||
event_time=datetime.datetime.fromtimestamp(event.event_time) if event.event_time else None,
|
event_time=datetime.datetime.fromtimestamp(event.event_time) if event.event_time else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _register_input_artifacts(
|
||||||
|
self,
|
||||||
|
event: AgentEventEnvelope,
|
||||||
|
run_id: str,
|
||||||
|
runner_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Register current-event attachments referenced by AgentInput."""
|
||||||
|
if not event.input or not event.input.attachments:
|
||||||
|
return
|
||||||
|
|
||||||
|
from .artifact_store import ArtifactStore
|
||||||
|
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||||
|
|
||||||
|
for attachment in event.input.attachments:
|
||||||
|
data = attachment.model_dump(mode='json') if hasattr(attachment, 'model_dump') else attachment
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
artifact_id = data.get('artifact_id')
|
||||||
|
artifact_type = data.get('artifact_type') or 'file'
|
||||||
|
if not artifact_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
content, parsed_mime_type = self._decode_attachment_content(data.get('content'))
|
||||||
|
url = data.get('url')
|
||||||
|
platform_ref_id = data.get('id')
|
||||||
|
storage_key = None
|
||||||
|
storage_type = 'metadata_only'
|
||||||
|
if content is None:
|
||||||
|
if url:
|
||||||
|
storage_key = url
|
||||||
|
storage_type = 'url'
|
||||||
|
elif platform_ref_id:
|
||||||
|
storage_key = platform_ref_id
|
||||||
|
storage_type = 'platform_ref'
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
'input_attachment': True,
|
||||||
|
'input_source': data.get('source') or 'platform',
|
||||||
|
}
|
||||||
|
if url:
|
||||||
|
metadata['url'] = url
|
||||||
|
if platform_ref_id:
|
||||||
|
metadata['platform_ref_id'] = platform_ref_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
await store.register_artifact(
|
||||||
|
artifact_id=artifact_id,
|
||||||
|
artifact_type=artifact_type,
|
||||||
|
source='platform',
|
||||||
|
storage_key=storage_key,
|
||||||
|
storage_type=storage_type,
|
||||||
|
mime_type=data.get('mime_type') or parsed_mime_type,
|
||||||
|
name=data.get('name'),
|
||||||
|
size_bytes=data.get('size') or (len(content) if content is not None else None),
|
||||||
|
conversation_id=event.conversation_id,
|
||||||
|
run_id=run_id,
|
||||||
|
runner_id=runner_id,
|
||||||
|
bot_id=event.bot_id,
|
||||||
|
workspace_id=event.workspace_id,
|
||||||
|
metadata=metadata,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(
|
||||||
|
f'Failed to register input artifact {artifact_id}: {e}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _decode_attachment_content(
|
||||||
|
self,
|
||||||
|
content: typing.Any,
|
||||||
|
) -> tuple[bytes | None, str | None]:
|
||||||
|
"""Decode base64 attachment content, including data URLs."""
|
||||||
|
if not isinstance(content, str) or not content:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import binascii
|
||||||
|
|
||||||
|
mime_type = None
|
||||||
|
payload = content
|
||||||
|
if content.startswith('data:') and ',' in content:
|
||||||
|
header, payload = content.split(',', 1)
|
||||||
|
if ';base64' in header:
|
||||||
|
mime_type = header[5:].split(';', 1)[0] or None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return base64.b64decode(payload, validate=False), mime_type
|
||||||
|
except (binascii.Error, ValueError):
|
||||||
|
return None, mime_type
|
||||||
|
|
||||||
async def _write_user_transcript(
|
async def _write_user_transcript(
|
||||||
self,
|
self,
|
||||||
event: AgentEventEnvelope,
|
event: AgentEventEnvelope,
|
||||||
|
|||||||
@@ -250,6 +250,8 @@ class PersistentStateStore:
|
|||||||
|
|
||||||
Used by State API handlers.
|
Used by State API handlers.
|
||||||
"""
|
"""
|
||||||
|
state_key = normalize_state_key(state_key)
|
||||||
|
|
||||||
async with self._db_engine.connect() as conn:
|
async with self._db_engine.connect() as conn:
|
||||||
result = await conn.execute(
|
result = await conn.execute(
|
||||||
select(AgentRunnerState.value_json)
|
select(AgentRunnerState.value_json)
|
||||||
@@ -282,6 +284,8 @@ class PersistentStateStore:
|
|||||||
Used by State API handlers.
|
Used by State API handlers.
|
||||||
Context contains optional fields like bot_id, conversation_id, etc.
|
Context contains optional fields like bot_id, conversation_id, etc.
|
||||||
"""
|
"""
|
||||||
|
state_key = normalize_state_key(state_key)
|
||||||
|
|
||||||
# Validate and serialize value
|
# Validate and serialize value
|
||||||
value_json, error = self._validate_json_value(value, logger)
|
value_json, error = self._validate_json_value(value, logger)
|
||||||
if error:
|
if error:
|
||||||
@@ -344,6 +348,8 @@ class PersistentStateStore:
|
|||||||
|
|
||||||
Returns True if deleted, False if not found.
|
Returns True if deleted, False if not found.
|
||||||
"""
|
"""
|
||||||
|
state_key = normalize_state_key(state_key)
|
||||||
|
|
||||||
async with self._db_engine.begin() as conn:
|
async with self._db_engine.begin() as conn:
|
||||||
result = await conn.execute(
|
result = await conn.execute(
|
||||||
delete(AgentRunnerState)
|
delete(AgentRunnerState)
|
||||||
@@ -376,6 +382,7 @@ class PersistentStateStore:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prefix:
|
if prefix:
|
||||||
|
prefix = normalize_state_key(prefix)
|
||||||
query = query.where(
|
query = query.where(
|
||||||
AgentRunnerState.state_key.like(f'{prefix}%')
|
AgentRunnerState.state_key.like(f'{prefix}%')
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -103,12 +103,13 @@ class AgentResourceBuilder:
|
|||||||
models: list[ModelResource] = []
|
models: list[ModelResource] = []
|
||||||
seen_model_ids: set[str] = set()
|
seen_model_ids: set[str] = set()
|
||||||
|
|
||||||
# Check manifest permission
|
|
||||||
model_perms = manifest_perms.get('models', [])
|
model_perms = manifest_perms.get('models', [])
|
||||||
if 'invoke' not in model_perms and 'stream' not in model_perms:
|
allow_llm = 'invoke' in model_perms or 'stream' in model_perms
|
||||||
|
allow_rerank = 'rerank' in model_perms
|
||||||
|
if not allow_llm and not allow_rerank:
|
||||||
return models
|
return models
|
||||||
|
|
||||||
# Get model UUIDs from resource policy
|
# Get additional model UUID grants from resource policy.
|
||||||
allowed_uuids = resource_policy.allowed_model_uuids
|
allowed_uuids = resource_policy.allowed_model_uuids
|
||||||
|
|
||||||
# Add model resources from binding config schema
|
# Add model resources from binding config schema
|
||||||
@@ -117,10 +118,12 @@ class AgentResourceBuilder:
|
|||||||
seen_model_ids=seen_model_ids,
|
seen_model_ids=seen_model_ids,
|
||||||
descriptor=descriptor,
|
descriptor=descriptor,
|
||||||
runner_config=runner_config,
|
runner_config=runner_config,
|
||||||
|
include_llm=allow_llm,
|
||||||
|
include_rerank=allow_rerank,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add explicitly allowed models
|
# Add explicitly allowed models
|
||||||
if allowed_uuids:
|
if allowed_uuids and allow_llm:
|
||||||
for model_uuid in allowed_uuids:
|
for model_uuid in allowed_uuids:
|
||||||
await self._append_llm_model_resource(models, seen_model_ids, model_uuid)
|
await self._append_llm_model_resource(models, seen_model_ids, model_uuid)
|
||||||
|
|
||||||
@@ -168,13 +171,13 @@ class AgentResourceBuilder:
|
|||||||
if 'list' not in kb_perms and 'retrieve' not in kb_perms:
|
if 'list' not in kb_perms and 'retrieve' not in kb_perms:
|
||||||
return kb_resources
|
return kb_resources
|
||||||
|
|
||||||
# Get KB UUIDs from schema-defined config fields
|
# Get KB UUID grants from schema-defined config fields.
|
||||||
kb_uuids = config_schema.extract_knowledge_base_uuids(descriptor, runner_config)
|
kb_uuids = config_schema.extract_knowledge_base_uuids(descriptor, runner_config)
|
||||||
|
|
||||||
# Also check resource policy
|
# Also include resource policy grants.
|
||||||
allowed_uuids = resource_policy.allowed_kb_uuids
|
allowed_uuids = resource_policy.allowed_kb_uuids
|
||||||
if allowed_uuids:
|
if allowed_uuids:
|
||||||
kb_uuids = allowed_uuids
|
kb_uuids = list(dict.fromkeys([*kb_uuids, *allowed_uuids]))
|
||||||
|
|
||||||
for kb_uuid in kb_uuids:
|
for kb_uuid in kb_uuids:
|
||||||
try:
|
try:
|
||||||
@@ -210,12 +213,14 @@ class AgentResourceBuilder:
|
|||||||
seen_model_ids: set[str],
|
seen_model_ids: set[str],
|
||||||
descriptor: AgentRunnerDescriptor,
|
descriptor: AgentRunnerDescriptor,
|
||||||
runner_config: dict[str, typing.Any],
|
runner_config: dict[str, typing.Any],
|
||||||
|
include_llm: bool,
|
||||||
|
include_rerank: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Authorize model-like values selected through DynamicForm fields."""
|
"""Authorize model-like values selected through DynamicForm fields."""
|
||||||
for model_type, model_uuid in config_schema.iter_config_model_refs(descriptor, runner_config):
|
for model_type, model_uuid in config_schema.iter_config_model_refs(descriptor, runner_config):
|
||||||
if model_type == 'llm':
|
if model_type == 'llm' and include_llm:
|
||||||
await self._append_llm_model_resource(models, seen_model_ids, model_uuid)
|
await self._append_llm_model_resource(models, seen_model_ids, model_uuid)
|
||||||
elif model_type == 'rerank':
|
elif model_type == 'rerank' and include_rerank:
|
||||||
await self._append_rerank_model_resource(models, seen_model_ids, model_uuid)
|
await self._append_rerank_model_resource(models, seen_model_ids, model_uuid)
|
||||||
|
|
||||||
async def _append_llm_model_resource(
|
async def _append_llm_model_resource(
|
||||||
|
|||||||
@@ -73,9 +73,6 @@ class TranscriptStore:
|
|||||||
if content and len(content) > self.MAX_CONTENT_LENGTH:
|
if content and len(content) > self.MAX_CONTENT_LENGTH:
|
||||||
content = content[:self.MAX_CONTENT_LENGTH - 3] + "..."
|
content = content[:self.MAX_CONTENT_LENGTH - 3] + "..."
|
||||||
|
|
||||||
# Get next sequence number for this conversation
|
|
||||||
seq = await self._get_next_seq(conversation_id)
|
|
||||||
|
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
item = Transcript(
|
item = Transcript(
|
||||||
transcript_id=transcript_id,
|
transcript_id=transcript_id,
|
||||||
@@ -87,13 +84,15 @@ class TranscriptStore:
|
|||||||
content=content,
|
content=content,
|
||||||
content_json=json.dumps(content_json) if content_json else None,
|
content_json=json.dumps(content_json) if content_json else None,
|
||||||
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
|
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
|
||||||
seq=seq,
|
seq=0,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
runner_id=runner_id,
|
runner_id=runner_id,
|
||||||
created_at=datetime.datetime.utcnow(),
|
created_at=datetime.datetime.utcnow(),
|
||||||
metadata_json=json.dumps(metadata) if metadata else None,
|
metadata_json=json.dumps(metadata) if metadata else None,
|
||||||
)
|
)
|
||||||
session.add(item)
|
session.add(item)
|
||||||
|
await session.flush()
|
||||||
|
item.seq = item.id or await self._get_next_seq(conversation_id)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return transcript_id
|
return transcript_id
|
||||||
@@ -253,7 +252,7 @@ class TranscriptStore:
|
|||||||
return count > 0
|
return count > 0
|
||||||
|
|
||||||
async def _get_next_seq(self, conversation_id: str) -> int:
|
async def _get_next_seq(self, conversation_id: str) -> int:
|
||||||
"""Get the next sequence number for a conversation."""
|
"""Fallback next sequence number for stores that cannot expose autoincrement IDs."""
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))
|
sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class Transcript(Base):
|
|||||||
|
|
||||||
# Sequence for cursor-based pagination
|
# Sequence for cursor-based pagination
|
||||||
seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, index=True)
|
seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, index=True)
|
||||||
"""Sequence number within conversation (auto-increment per conversation)."""
|
"""Monotonic cursor sequence for pagination."""
|
||||||
|
|
||||||
# Context
|
# Context
|
||||||
run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
run_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True, index=True)
|
||||||
|
|||||||
+156
-144
@@ -141,6 +141,70 @@ def _validate_artifact_access(
|
|||||||
return False, f'Artifact {operation} access denied: artifact not in session conversation and not created by this run'
|
return False, f'Artifact {operation} access denied: artifact not in session conversation and not created by this run'
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_agent_run_session(
|
||||||
|
run_id: str,
|
||||||
|
caller_plugin_identity: str | None,
|
||||||
|
ap: app.Application,
|
||||||
|
api_name: str,
|
||||||
|
permission_group: str | None = None,
|
||||||
|
permission_operation: str | None = None,
|
||||||
|
) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]:
|
||||||
|
"""Validate an AgentRunner pull API run session and optional manifest permission."""
|
||||||
|
session_registry = get_session_registry()
|
||||||
|
session = await session_registry.get(run_id)
|
||||||
|
if not session:
|
||||||
|
return None, handler.ActionResponse.error(
|
||||||
|
message=f'Run session {run_id} not found or expired'
|
||||||
|
)
|
||||||
|
|
||||||
|
session_plugin_identity = session.get('plugin_identity')
|
||||||
|
if session_plugin_identity:
|
||||||
|
if not caller_plugin_identity:
|
||||||
|
return None, handler.ActionResponse.error(
|
||||||
|
message=f'caller_plugin_identity is required for run_id {run_id}'
|
||||||
|
)
|
||||||
|
if caller_plugin_identity != session_plugin_identity:
|
||||||
|
ap.logger.warning(
|
||||||
|
f'{api_name}: caller_plugin_identity {caller_plugin_identity} '
|
||||||
|
f'does not match session plugin_identity {session_plugin_identity}'
|
||||||
|
)
|
||||||
|
return None, handler.ActionResponse.error(
|
||||||
|
message=f'Plugin identity mismatch for run_id {run_id}'
|
||||||
|
)
|
||||||
|
|
||||||
|
if permission_group and permission_operation:
|
||||||
|
permissions = session.get('permissions', {})
|
||||||
|
allowed_operations = permissions.get(permission_group, [])
|
||||||
|
if permission_operation not in allowed_operations:
|
||||||
|
return None, handler.ActionResponse.error(
|
||||||
|
message=f'{api_name} access not authorized'
|
||||||
|
)
|
||||||
|
|
||||||
|
return session, None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_run_conversation(
|
||||||
|
session: dict[str, Any],
|
||||||
|
requested_conversation_id: str | None,
|
||||||
|
api_name: str,
|
||||||
|
) -> tuple[str | None, handler.ActionResponse | None]:
|
||||||
|
"""Resolve and enforce current-run conversation scope."""
|
||||||
|
session_conversation_id = session.get('conversation_id')
|
||||||
|
|
||||||
|
if requested_conversation_id:
|
||||||
|
if not session_conversation_id:
|
||||||
|
return None, handler.ActionResponse.error(
|
||||||
|
message=f'{api_name} is not available without a run conversation'
|
||||||
|
)
|
||||||
|
if requested_conversation_id != session_conversation_id:
|
||||||
|
return None, handler.ActionResponse.error(
|
||||||
|
message=f'Conversation {requested_conversation_id} is not accessible by this run'
|
||||||
|
)
|
||||||
|
return requested_conversation_id, None
|
||||||
|
|
||||||
|
return session_conversation_id, None
|
||||||
|
|
||||||
|
|
||||||
def _normalize_uuid_list(values: Any) -> list[str]:
|
def _normalize_uuid_list(values: Any) -> list[str]:
|
||||||
"""Normalize a user/config supplied UUID list while preserving order."""
|
"""Normalize a user/config supplied UUID list while preserving order."""
|
||||||
if not isinstance(values, list):
|
if not isinstance(values, list):
|
||||||
@@ -1197,7 +1261,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
kb_id = data['kb_id']
|
kb_id = data['kb_id']
|
||||||
query_text = data['query_text']
|
query_text = data['query_text']
|
||||||
top_k = data.get('top_k', 5)
|
top_k = data.get('top_k', 5)
|
||||||
filters = data.get('filters', {})
|
filters = data.get('filters') or {}
|
||||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||||
|
|
||||||
@@ -1271,7 +1335,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
kb_id = data['kb_id']
|
kb_id = data['kb_id']
|
||||||
query_text = data['query_text']
|
query_text = data['query_text']
|
||||||
top_k = data.get('top_k', 5)
|
top_k = data.get('top_k', 5)
|
||||||
filters = data.get('filters', {})
|
filters = data.get('filters') or {}
|
||||||
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
run_id = data.get('run_id') # Optional: present for AgentRunner calls
|
||||||
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
caller_plugin_identity = data.get('caller_plugin_identity') # Optional: for cross-plugin validation
|
||||||
|
|
||||||
@@ -1342,29 +1406,24 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
if not run_id:
|
if not run_id:
|
||||||
return handler.ActionResponse.error(message='run_id is required')
|
return handler.ActionResponse.error(message='run_id is required')
|
||||||
|
|
||||||
# Validate run session
|
session, error = await _validate_agent_run_session(
|
||||||
session_registry = get_session_registry()
|
run_id,
|
||||||
session = await session_registry.get(run_id)
|
caller_plugin_identity,
|
||||||
if not session:
|
self.ap,
|
||||||
return handler.ActionResponse.error(
|
'History page',
|
||||||
message=f'Run session {run_id} not found or expired'
|
permission_group='history',
|
||||||
)
|
permission_operation='page',
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return error
|
||||||
|
|
||||||
# Validate caller plugin identity (strict: required when session has plugin_identity)
|
conversation_id, scope_error = _resolve_run_conversation(
|
||||||
session_plugin_identity = session.get('plugin_identity')
|
session,
|
||||||
if session_plugin_identity:
|
conversation_id,
|
||||||
if not caller_plugin_identity:
|
'History page',
|
||||||
return handler.ActionResponse.error(
|
)
|
||||||
message=f'caller_plugin_identity is required for run_id {run_id}'
|
if scope_error:
|
||||||
)
|
return scope_error
|
||||||
if caller_plugin_identity != session_plugin_identity:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get conversation from session if not provided
|
|
||||||
if not conversation_id:
|
|
||||||
conversation_id = session.get('conversation_id')
|
|
||||||
|
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
return handler.ActionResponse.success(data={
|
return handler.ActionResponse.success(data={
|
||||||
@@ -1411,35 +1470,32 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
"""
|
"""
|
||||||
run_id = data.get('run_id')
|
run_id = data.get('run_id')
|
||||||
query_text = data.get('query', '')
|
query_text = data.get('query', '')
|
||||||
filters = data.get('filters', {})
|
filters = data.get('filters') or {}
|
||||||
top_k = data.get('top_k', 10)
|
top_k = data.get('top_k', 10)
|
||||||
caller_plugin_identity = data.get('caller_plugin_identity')
|
caller_plugin_identity = data.get('caller_plugin_identity')
|
||||||
|
|
||||||
if not run_id:
|
if not run_id:
|
||||||
return handler.ActionResponse.error(message='run_id is required')
|
return handler.ActionResponse.error(message='run_id is required')
|
||||||
|
|
||||||
# Validate run session
|
session, error = await _validate_agent_run_session(
|
||||||
session_registry = get_session_registry()
|
run_id,
|
||||||
session = await session_registry.get(run_id)
|
caller_plugin_identity,
|
||||||
if not session:
|
self.ap,
|
||||||
return handler.ActionResponse.error(
|
'History search',
|
||||||
message=f'Run session {run_id} not found or expired'
|
permission_group='history',
|
||||||
)
|
permission_operation='search',
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return error
|
||||||
|
|
||||||
# Validate caller plugin identity (strict: required when session has plugin_identity)
|
requested_conversation_id = filters.get('conversation_id')
|
||||||
session_plugin_identity = session.get('plugin_identity')
|
conversation_id, scope_error = _resolve_run_conversation(
|
||||||
if session_plugin_identity:
|
session,
|
||||||
if not caller_plugin_identity:
|
requested_conversation_id,
|
||||||
return handler.ActionResponse.error(
|
'History search',
|
||||||
message=f'caller_plugin_identity is required for run_id {run_id}'
|
)
|
||||||
)
|
if scope_error:
|
||||||
if caller_plugin_identity != session_plugin_identity:
|
return scope_error
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get conversation from session or filters
|
|
||||||
conversation_id = filters.get('conversation_id') or session.get('conversation_id')
|
|
||||||
|
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
return handler.ActionResponse.success(data={
|
return handler.ActionResponse.success(data={
|
||||||
@@ -1453,10 +1509,11 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
store = TranscriptStore(self.ap.persistence_mgr.get_db_engine())
|
store = TranscriptStore(self.ap.persistence_mgr.get_db_engine())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
safe_filters = {k: v for k, v in filters.items() if k != 'conversation_id'}
|
||||||
items = await store.search_transcript(
|
items = await store.search_transcript(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
filters=filters,
|
filters=safe_filters,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1485,25 +1542,16 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
if not event_id:
|
if not event_id:
|
||||||
return handler.ActionResponse.error(message='event_id is required')
|
return handler.ActionResponse.error(message='event_id is required')
|
||||||
|
|
||||||
# Validate run session
|
session, error = await _validate_agent_run_session(
|
||||||
session_registry = get_session_registry()
|
run_id,
|
||||||
session = await session_registry.get(run_id)
|
caller_plugin_identity,
|
||||||
if not session:
|
self.ap,
|
||||||
return handler.ActionResponse.error(
|
'Event get',
|
||||||
message=f'Run session {run_id} not found or expired'
|
permission_group='events',
|
||||||
)
|
permission_operation='get',
|
||||||
|
)
|
||||||
# Validate caller plugin identity (strict: required when session has plugin_identity)
|
if error:
|
||||||
session_plugin_identity = session.get('plugin_identity')
|
return error
|
||||||
if session_plugin_identity:
|
|
||||||
if not caller_plugin_identity:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'caller_plugin_identity is required for run_id {run_id}'
|
|
||||||
)
|
|
||||||
if caller_plugin_identity != session_plugin_identity:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get event
|
# Get event
|
||||||
from ..agent.runner.event_log_store import EventLogStore
|
from ..agent.runner.event_log_store import EventLogStore
|
||||||
@@ -1516,9 +1564,12 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
message=f'Event {event_id} not found'
|
message=f'Event {event_id} not found'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate event is in the same conversation as the run
|
# Validate event is in the same conversation as the run, or was created by the same run.
|
||||||
session_conversation_id = session.get('conversation_id')
|
session_conversation_id = session.get('conversation_id')
|
||||||
if session_conversation_id and event.get('conversation_id') != session_conversation_id:
|
event_run_id = event.get('run_id')
|
||||||
|
if event_run_id and event_run_id == run_id:
|
||||||
|
return handler.ActionResponse.success(data=event)
|
||||||
|
if not session_conversation_id or event.get('conversation_id') != session_conversation_id:
|
||||||
return handler.ActionResponse.error(
|
return handler.ActionResponse.error(
|
||||||
message=f'Event {event_id} is not accessible by this run'
|
message=f'Event {event_id} is not accessible by this run'
|
||||||
)
|
)
|
||||||
@@ -1544,29 +1595,24 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
if not run_id:
|
if not run_id:
|
||||||
return handler.ActionResponse.error(message='run_id is required')
|
return handler.ActionResponse.error(message='run_id is required')
|
||||||
|
|
||||||
# Validate run session
|
session, error = await _validate_agent_run_session(
|
||||||
session_registry = get_session_registry()
|
run_id,
|
||||||
session = await session_registry.get(run_id)
|
caller_plugin_identity,
|
||||||
if not session:
|
self.ap,
|
||||||
return handler.ActionResponse.error(
|
'Event page',
|
||||||
message=f'Run session {run_id} not found or expired'
|
permission_group='events',
|
||||||
)
|
permission_operation='page',
|
||||||
|
)
|
||||||
|
if error:
|
||||||
|
return error
|
||||||
|
|
||||||
# Validate caller plugin identity (strict: required when session has plugin_identity)
|
conversation_id, scope_error = _resolve_run_conversation(
|
||||||
session_plugin_identity = session.get('plugin_identity')
|
session,
|
||||||
if session_plugin_identity:
|
conversation_id,
|
||||||
if not caller_plugin_identity:
|
'Event page',
|
||||||
return handler.ActionResponse.error(
|
)
|
||||||
message=f'caller_plugin_identity is required for run_id {run_id}'
|
if scope_error:
|
||||||
)
|
return scope_error
|
||||||
if caller_plugin_identity != session_plugin_identity:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get conversation from session if not provided
|
|
||||||
if not conversation_id:
|
|
||||||
conversation_id = session.get('conversation_id')
|
|
||||||
|
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
return handler.ActionResponse.success(data={
|
return handler.ActionResponse.success(data={
|
||||||
@@ -1620,33 +1666,16 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
if not artifact_id:
|
if not artifact_id:
|
||||||
return handler.ActionResponse.error(message='artifact_id is required')
|
return handler.ActionResponse.error(message='artifact_id is required')
|
||||||
|
|
||||||
# Validate run session
|
session, error = await _validate_agent_run_session(
|
||||||
session_registry = get_session_registry()
|
run_id,
|
||||||
session = await session_registry.get(run_id)
|
caller_plugin_identity,
|
||||||
if not session:
|
self.ap,
|
||||||
return handler.ActionResponse.error(
|
'Artifact metadata',
|
||||||
message=f'Run session {run_id} not found or expired'
|
permission_group='artifacts',
|
||||||
)
|
permission_operation='metadata',
|
||||||
|
)
|
||||||
# Validate caller plugin identity (strict: required when session has plugin_identity)
|
if error:
|
||||||
session_plugin_identity = session.get('plugin_identity')
|
return error
|
||||||
if session_plugin_identity:
|
|
||||||
if not caller_plugin_identity:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'caller_plugin_identity is required for run_id {run_id}'
|
|
||||||
)
|
|
||||||
if caller_plugin_identity != session_plugin_identity:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check artifact permission from session.permissions (from descriptor.permissions)
|
|
||||||
permissions = session.get('permissions', {})
|
|
||||||
artifact_permissions = permissions.get('artifacts', [])
|
|
||||||
if 'metadata' not in artifact_permissions:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message='Artifact metadata access not authorized'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get artifact metadata
|
# Get artifact metadata
|
||||||
from ..agent.runner.artifact_store import ArtifactStore
|
from ..agent.runner.artifact_store import ArtifactStore
|
||||||
@@ -1708,33 +1737,16 @@ class RuntimeConnectionHandler(handler.Handler):
|
|||||||
if limit <= 0:
|
if limit <= 0:
|
||||||
return handler.ActionResponse.error(message='limit must be > 0')
|
return handler.ActionResponse.error(message='limit must be > 0')
|
||||||
|
|
||||||
# Validate run session
|
session, error = await _validate_agent_run_session(
|
||||||
session_registry = get_session_registry()
|
run_id,
|
||||||
session = await session_registry.get(run_id)
|
caller_plugin_identity,
|
||||||
if not session:
|
self.ap,
|
||||||
return handler.ActionResponse.error(
|
'Artifact read',
|
||||||
message=f'Run session {run_id} not found or expired'
|
permission_group='artifacts',
|
||||||
)
|
permission_operation='read',
|
||||||
|
)
|
||||||
# Validate caller plugin identity (strict: required when session has plugin_identity)
|
if error:
|
||||||
session_plugin_identity = session.get('plugin_identity')
|
return error
|
||||||
if session_plugin_identity:
|
|
||||||
if not caller_plugin_identity:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'caller_plugin_identity is required for run_id {run_id}'
|
|
||||||
)
|
|
||||||
if caller_plugin_identity != session_plugin_identity:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check artifact permission from session.permissions (from descriptor.permissions)
|
|
||||||
permissions = session.get('permissions', {})
|
|
||||||
artifact_permissions = permissions.get('artifacts', [])
|
|
||||||
if 'read' not in artifact_permissions:
|
|
||||||
return handler.ActionResponse.error(
|
|
||||||
message='Artifact read access not authorized'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get artifact metadata first to validate access
|
# Get artifact metadata first to validate access
|
||||||
from ..agent.runner.artifact_store import ArtifactStore
|
from ..agent.runner.artifact_store import ArtifactStore
|
||||||
|
|||||||
@@ -0,0 +1,146 @@
|
|||||||
|
"""Tests for AgentRunner history/event pull API authorization."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
|
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry
|
||||||
|
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
|
||||||
|
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||||
|
|
||||||
|
from .conftest import make_resources
|
||||||
|
|
||||||
|
|
||||||
|
class FakeConnection:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FakeApplication:
|
||||||
|
def __init__(self, db_engine):
|
||||||
|
self.logger = MagicMock()
|
||||||
|
self.persistence_mgr = MagicMock()
|
||||||
|
self.persistence_mgr.get_db_engine = MagicMock(return_value=db_engine)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session_registry(monkeypatch):
|
||||||
|
registry = AgentRunSessionRegistry()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
'langbot.pkg.agent.runner.session_registry._global_registry',
|
||||||
|
registry,
|
||||||
|
)
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_engine():
|
||||||
|
engine = create_async_engine('sqlite+aiosqlite:///:memory:')
|
||||||
|
yield engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def _handler(db_engine, session_registry):
|
||||||
|
async def fake_disconnect():
|
||||||
|
return True
|
||||||
|
|
||||||
|
fake_app = FakeApplication(db_engine)
|
||||||
|
return RuntimeConnectionHandler(FakeConnection(), fake_disconnect, fake_app)
|
||||||
|
|
||||||
|
|
||||||
|
async def _register_session(
|
||||||
|
session_registry,
|
||||||
|
*,
|
||||||
|
run_id='run_1',
|
||||||
|
conversation_id='conv_1',
|
||||||
|
permissions=None,
|
||||||
|
):
|
||||||
|
await session_registry.register(
|
||||||
|
run_id=run_id,
|
||||||
|
runner_id='plugin:test/runner/default',
|
||||||
|
query_id=None,
|
||||||
|
plugin_identity='test/runner',
|
||||||
|
resources=make_resources(),
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
permissions=permissions or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_page_requires_manifest_permission(session_registry, db_engine):
|
||||||
|
await _register_session(session_registry, permissions={'history': []})
|
||||||
|
handler = _handler(db_engine, session_registry)
|
||||||
|
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
|
||||||
|
|
||||||
|
result = await history_page({
|
||||||
|
'run_id': 'run_1',
|
||||||
|
'caller_plugin_identity': 'test/runner',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.code != 0
|
||||||
|
assert 'not authorized' in result.message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_page_rejects_cross_conversation(session_registry, db_engine):
|
||||||
|
await _register_session(session_registry, permissions={'history': ['page']})
|
||||||
|
handler = _handler(db_engine, session_registry)
|
||||||
|
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
|
||||||
|
|
||||||
|
result = await history_page({
|
||||||
|
'run_id': 'run_1',
|
||||||
|
'conversation_id': 'conv_other',
|
||||||
|
'caller_plugin_identity': 'test/runner',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.code != 0
|
||||||
|
assert 'not accessible' in result.message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_search_rejects_filter_conversation_override(session_registry, db_engine):
|
||||||
|
await _register_session(session_registry, permissions={'history': ['search']})
|
||||||
|
handler = _handler(db_engine, session_registry)
|
||||||
|
history_search = handler.actions[PluginToRuntimeAction.HISTORY_SEARCH.value]
|
||||||
|
|
||||||
|
result = await history_search({
|
||||||
|
'run_id': 'run_1',
|
||||||
|
'query': 'hello',
|
||||||
|
'filters': {'conversation_id': 'conv_other'},
|
||||||
|
'caller_plugin_identity': 'test/runner',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.code != 0
|
||||||
|
assert 'not accessible' in result.message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_page_requires_manifest_permission(session_registry, db_engine):
|
||||||
|
await _register_session(session_registry, permissions={'events': []})
|
||||||
|
handler = _handler(db_engine, session_registry)
|
||||||
|
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||||
|
|
||||||
|
result = await event_page({
|
||||||
|
'run_id': 'run_1',
|
||||||
|
'caller_plugin_identity': 'test/runner',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.code != 0
|
||||||
|
assert 'not authorized' in result.message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_page_rejects_cross_conversation(session_registry, db_engine):
|
||||||
|
await _register_session(session_registry, permissions={'events': ['page']})
|
||||||
|
handler = _handler(db_engine, session_registry)
|
||||||
|
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||||
|
|
||||||
|
result = await event_page({
|
||||||
|
'run_id': 'run_1',
|
||||||
|
'conversation_id': 'conv_other',
|
||||||
|
'caller_plugin_identity': 'test/runner',
|
||||||
|
})
|
||||||
|
|
||||||
|
assert result.code != 0
|
||||||
|
assert 'not accessible' in result.message.lower()
|
||||||
@@ -18,6 +18,7 @@ def make_descriptor(
|
|||||||
*,
|
*,
|
||||||
permissions: dict | None = None,
|
permissions: dict | None = None,
|
||||||
config_schema: list[dict] | None = None,
|
config_schema: list[dict] | None = None,
|
||||||
|
capabilities: dict | None = None,
|
||||||
) -> AgentRunnerDescriptor:
|
) -> AgentRunnerDescriptor:
|
||||||
return AgentRunnerDescriptor(
|
return AgentRunnerDescriptor(
|
||||||
id=RUNNER_ID,
|
id=RUNNER_ID,
|
||||||
@@ -26,6 +27,7 @@ def make_descriptor(
|
|||||||
plugin_author='test',
|
plugin_author='test',
|
||||||
plugin_name='runner',
|
plugin_name='runner',
|
||||||
runner_name='default',
|
runner_name='default',
|
||||||
|
capabilities=capabilities or {},
|
||||||
permissions=permissions or {'models': ['invoke', 'stream']},
|
permissions=permissions or {'models': ['invoke', 'stream']},
|
||||||
config_schema=config_schema or [],
|
config_schema=config_schema or [],
|
||||||
)
|
)
|
||||||
@@ -99,6 +101,7 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app
|
|||||||
app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=get_model_by_uuid)
|
app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=get_model_by_uuid)
|
||||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(side_effect=get_rerank_model_by_uuid)
|
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(side_effect=get_rerank_model_by_uuid)
|
||||||
descriptor = make_descriptor(
|
descriptor = make_descriptor(
|
||||||
|
permissions={'models': ['invoke', 'stream', 'rerank']},
|
||||||
config_schema=[
|
config_schema=[
|
||||||
{'name': 'model', 'type': 'model-fallback-selector'},
|
{'name': 'model', 'type': 'model-fallback-selector'},
|
||||||
{'name': 'aux-model', 'type': 'llm-model-selector'},
|
{'name': 'aux-model', 'type': 'llm-model-selector'},
|
||||||
@@ -145,6 +148,33 @@ async def test_build_models_still_honors_manifest_permissions(app):
|
|||||||
app.model_mgr.get_rerank_model_by_uuid.assert_not_awaited()
|
app.model_mgr.get_rerank_model_by_uuid.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_models_authorizes_rerank_only_runner(app):
|
||||||
|
"""A rerank-only runner should receive config-selected rerank models."""
|
||||||
|
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||||
|
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
|
||||||
|
return_value=make_model(model_type='rerank', provider='rerank-provider')
|
||||||
|
)
|
||||||
|
descriptor = make_descriptor(
|
||||||
|
permissions={'models': ['rerank']},
|
||||||
|
config_schema=[
|
||||||
|
{'name': 'model', 'type': 'llm-model-selector'},
|
||||||
|
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
query = make_query({
|
||||||
|
'model': 'llm',
|
||||||
|
'rerank-model': 'rerank',
|
||||||
|
})
|
||||||
|
|
||||||
|
resources = await build_resources(app, query, descriptor)
|
||||||
|
|
||||||
|
assert resources['models'] == [
|
||||||
|
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider'},
|
||||||
|
]
|
||||||
|
app.model_mgr.get_model_by_uuid.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_build_models_deduplicates_query_and_config_models(app):
|
async def test_build_models_deduplicates_query_and_config_models(app):
|
||||||
"""A model selected by both preproc and runner config should appear once."""
|
"""A model selected by both preproc and runner config should appear once."""
|
||||||
@@ -197,3 +227,37 @@ async def test_build_tools_authorizes_query_declared_tools(app):
|
|||||||
'description': None,
|
'description': None,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_knowledge_bases_unions_config_and_policy_grants(app):
|
||||||
|
descriptor = make_descriptor(
|
||||||
|
capabilities={'knowledge_retrieval': True},
|
||||||
|
permissions={
|
||||||
|
'models': [],
|
||||||
|
'knowledge_bases': ['retrieve'],
|
||||||
|
},
|
||||||
|
config_schema=[
|
||||||
|
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
query = make_query(
|
||||||
|
{'knowledge-bases': ['kb_config']},
|
||||||
|
variables={'_knowledge_base_uuids': ['kb_policy']},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_kb(kb_uuid):
|
||||||
|
return SimpleNamespace(
|
||||||
|
uuid=kb_uuid,
|
||||||
|
get_name=lambda: f'name-{kb_uuid}',
|
||||||
|
knowledge_base_entity=SimpleNamespace(kb_type='default'),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(side_effect=get_kb)
|
||||||
|
|
||||||
|
resources = await build_resources(app, query, descriptor)
|
||||||
|
|
||||||
|
assert resources['knowledge_bases'] == [
|
||||||
|
{'kb_id': 'kb_config', 'kb_name': 'name-kb_config', 'kb_type': 'default'},
|
||||||
|
{'kb_id': 'kb_policy', 'kb_name': 'name-kb_policy', 'kb_type': 'default'},
|
||||||
|
]
|
||||||
|
|||||||
@@ -213,6 +213,30 @@ class TestPersistentStateStore:
|
|||||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||||
assert snapshot['conversation']['test_key'] == {'nested': 'value'}
|
assert snapshot['conversation']['test_key'] == {'nested': 'value'}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_api_methods_normalize_public_key_aliases(self, persistent_store):
|
||||||
|
scope_key = 'conversation:runner:binding:conv_001'
|
||||||
|
|
||||||
|
success, error = await persistent_store.state_set(
|
||||||
|
scope_key=scope_key,
|
||||||
|
state_key='conversation_id',
|
||||||
|
value='conv_001',
|
||||||
|
runner_id='plugin:test/my-runner/default',
|
||||||
|
binding_identity='binding_001',
|
||||||
|
scope='conversation',
|
||||||
|
)
|
||||||
|
|
||||||
|
assert success is True
|
||||||
|
assert error is None
|
||||||
|
assert await persistent_store.state_get(scope_key, 'external.conversation_id') == 'conv_001'
|
||||||
|
assert await persistent_store.state_get(scope_key, 'conversation_id') == 'conv_001'
|
||||||
|
|
||||||
|
keys, _ = await persistent_store.state_list(scope_key, prefix='conversation_id')
|
||||||
|
assert keys == ['external.conversation_id']
|
||||||
|
|
||||||
|
assert await persistent_store.state_delete(scope_key, 'conversation_id') is True
|
||||||
|
assert await persistent_store.state_get(scope_key, 'external.conversation_id') is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_binding_isolation(self, persistent_store):
|
async def test_binding_isolation(self, persistent_store):
|
||||||
descriptor = make_descriptor()
|
descriptor = make_descriptor()
|
||||||
|
|||||||
Reference in New Issue
Block a user