fix: enforce agent run API permissions

This commit is contained in:
huanghuoguoguo
2026-05-30 20:14:06 +08:00
parent bbe7666642
commit 93cd852061
12 changed files with 522 additions and 166 deletions

View File

@@ -89,13 +89,13 @@ class ResourcePolicy(pydantic.BaseModel):
"""
allowed_model_uuids: list[str] | None = None
"""Allowed model UUIDs. None means all authorized."""
"""Additional model UUID grants. None means no additional model grants."""
allowed_tool_names: list[str] | None = None
"""Allowed tool names. None means all authorized."""
"""Additional tool name grants. None means no additional tool grants."""
allowed_kb_uuids: list[str] | None = None
"""Allowed knowledge base UUIDs. None means all authorized."""
"""Additional knowledge base UUID grants. None means no additional KB grants."""
allow_plugin_storage: bool = True
"""Whether plugin storage is allowed."""

View File

@@ -176,6 +176,13 @@ class AgentRunOrchestrator:
runner_id=descriptor.id,
)
# Register incoming attachments so input/transcript artifact_refs are resolvable.
await self._register_input_artifacts(
event=event,
run_id=run_id,
runner_id=descriptor.id,
)
# Write user message to Transcript if message.received
if event.event_type == 'message.received' and event.conversation_id:
await self._write_user_transcript(
@@ -526,6 +533,97 @@ class AgentRunOrchestrator:
event_time=datetime.datetime.fromtimestamp(event.event_time) if event.event_time else None,
)
async def _register_input_artifacts(
    self,
    event: AgentEventEnvelope,
    run_id: str,
    runner_id: str,
) -> None:
    """Record the attachments of the current event in the artifact store.

    Each attachment carrying an ``artifact_id`` is registered together with
    the run / runner / conversation identifiers of this event.  Inline
    content is decoded and passed along as bytes; when there is no usable
    inline content the attachment is registered by URL, then by platform
    reference id, and otherwise as metadata only.

    Registration is best-effort: a failure for one attachment is logged as
    a warning and the remaining attachments are still processed.

    Args:
        event: Envelope whose ``input.attachments`` are registered.
        run_id: Identifier of the current run.
        runner_id: Identifier of the runner handling the run.
    """
    if not (event.input and event.input.attachments):
        return

    from .artifact_store import ArtifactStore

    artifact_store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())

    for raw_attachment in event.input.attachments:
        # Model objects are flattened to plain dicts; anything else is used as-is.
        if hasattr(raw_attachment, 'model_dump'):
            fields = raw_attachment.model_dump(mode='json')
        else:
            fields = raw_attachment
        if not isinstance(fields, dict):
            continue

        artifact_id = fields.get('artifact_id')
        if not artifact_id:
            continue
        artifact_type = fields.get('artifact_type') or 'file'

        content, parsed_mime_type = self._decode_attachment_content(fields.get('content'))
        url = fields.get('url')
        platform_ref_id = fields.get('id')

        # Inline bytes win; otherwise point at the URL, then the platform reference.
        if content is not None:
            storage_key, storage_type = None, 'metadata_only'
        elif url:
            storage_key, storage_type = url, 'url'
        elif platform_ref_id:
            storage_key, storage_type = platform_ref_id, 'platform_ref'
        else:
            storage_key, storage_type = None, 'metadata_only'

        metadata = {
            'input_attachment': True,
            'input_source': fields.get('source') or 'platform',
        }
        for meta_key, meta_value in (('url', url), ('platform_ref_id', platform_ref_id)):
            if meta_value:
                metadata[meta_key] = meta_value

        try:
            # Arguments are built inside the try so that any failure while
            # reading them is handled the same way as a store failure.
            register_kwargs = {
                'artifact_id': artifact_id,
                'artifact_type': artifact_type,
                'source': 'platform',
                'storage_key': storage_key,
                'storage_type': storage_type,
                'mime_type': fields.get('mime_type') or parsed_mime_type,
                'name': fields.get('name'),
                'size_bytes': fields.get('size') or (len(content) if content is not None else None),
                'conversation_id': event.conversation_id,
                'run_id': run_id,
                'runner_id': runner_id,
                'bot_id': event.bot_id,
                'workspace_id': event.workspace_id,
                'metadata': metadata,
                'content': content,
            }
            await artifact_store.register_artifact(**register_kwargs)
        except Exception as exc:
            self.ap.logger.warning(
                f'Failed to register input artifact {artifact_id}: {exc}'
            )
def _decode_attachment_content(
self,
content: typing.Any,
) -> tuple[bytes | None, str | None]:
"""Decode base64 attachment content, including data URLs."""
if not isinstance(content, str) or not content:
return None, None
import base64
import binascii
mime_type = None
payload = content
if content.startswith('data:') and ',' in content:
header, payload = content.split(',', 1)
if ';base64' in header:
mime_type = header[5:].split(';', 1)[0] or None
try:
return base64.b64decode(payload, validate=False), mime_type
except (binascii.Error, ValueError):
return None, mime_type
async def _write_user_transcript(
self,
event: AgentEventEnvelope,

View File

@@ -250,6 +250,8 @@ class PersistentStateStore:
Used by State API handlers.
"""
state_key = normalize_state_key(state_key)
async with self._db_engine.connect() as conn:
result = await conn.execute(
select(AgentRunnerState.value_json)
@@ -282,6 +284,8 @@ class PersistentStateStore:
Used by State API handlers.
Context contains optional fields like bot_id, conversation_id, etc.
"""
state_key = normalize_state_key(state_key)
# Validate and serialize value
value_json, error = self._validate_json_value(value, logger)
if error:
@@ -344,6 +348,8 @@ class PersistentStateStore:
Returns True if deleted, False if not found.
"""
state_key = normalize_state_key(state_key)
async with self._db_engine.begin() as conn:
result = await conn.execute(
delete(AgentRunnerState)
@@ -376,6 +382,7 @@ class PersistentStateStore:
)
if prefix:
prefix = normalize_state_key(prefix)
query = query.where(
AgentRunnerState.state_key.like(f'{prefix}%')
)

View File

@@ -103,12 +103,13 @@ class AgentResourceBuilder:
models: list[ModelResource] = []
seen_model_ids: set[str] = set()
# Check manifest permission
model_perms = manifest_perms.get('models', [])
if 'invoke' not in model_perms and 'stream' not in model_perms:
allow_llm = 'invoke' in model_perms or 'stream' in model_perms
allow_rerank = 'rerank' in model_perms
if not allow_llm and not allow_rerank:
return models
# Get model UUIDs from resource policy
# Get additional model UUID grants from resource policy.
allowed_uuids = resource_policy.allowed_model_uuids
# Add model resources from binding config schema
@@ -117,10 +118,12 @@ class AgentResourceBuilder:
seen_model_ids=seen_model_ids,
descriptor=descriptor,
runner_config=runner_config,
include_llm=allow_llm,
include_rerank=allow_rerank,
)
# Add explicitly allowed models
if allowed_uuids:
if allowed_uuids and allow_llm:
for model_uuid in allowed_uuids:
await self._append_llm_model_resource(models, seen_model_ids, model_uuid)
@@ -168,13 +171,13 @@ class AgentResourceBuilder:
if 'list' not in kb_perms and 'retrieve' not in kb_perms:
return kb_resources
# Get KB UUIDs from schema-defined config fields
# Get KB UUID grants from schema-defined config fields.
kb_uuids = config_schema.extract_knowledge_base_uuids(descriptor, runner_config)
# Also check resource policy
# Also include resource policy grants.
allowed_uuids = resource_policy.allowed_kb_uuids
if allowed_uuids:
kb_uuids = allowed_uuids
kb_uuids = list(dict.fromkeys([*kb_uuids, *allowed_uuids]))
for kb_uuid in kb_uuids:
try:
@@ -210,12 +213,14 @@ class AgentResourceBuilder:
seen_model_ids: set[str],
descriptor: AgentRunnerDescriptor,
runner_config: dict[str, typing.Any],
include_llm: bool,
include_rerank: bool,
) -> None:
"""Authorize model-like values selected through DynamicForm fields."""
for model_type, model_uuid in config_schema.iter_config_model_refs(descriptor, runner_config):
if model_type == 'llm':
if model_type == 'llm' and include_llm:
await self._append_llm_model_resource(models, seen_model_ids, model_uuid)
elif model_type == 'rerank':
elif model_type == 'rerank' and include_rerank:
await self._append_rerank_model_resource(models, seen_model_ids, model_uuid)
async def _append_llm_model_resource(

View File

@@ -73,9 +73,6 @@ class TranscriptStore:
if content and len(content) > self.MAX_CONTENT_LENGTH:
content = content[:self.MAX_CONTENT_LENGTH - 3] + "..."
# Get next sequence number for this conversation
seq = await self._get_next_seq(conversation_id)
async with self._session_factory() as session:
item = Transcript(
transcript_id=transcript_id,
@@ -87,13 +84,15 @@ class TranscriptStore:
content=content,
content_json=json.dumps(content_json) if content_json else None,
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
seq=seq,
seq=0,
run_id=run_id,
runner_id=runner_id,
created_at=datetime.datetime.utcnow(),
metadata_json=json.dumps(metadata) if metadata else None,
)
session.add(item)
await session.flush()
item.seq = item.id or await self._get_next_seq(conversation_id)
await session.commit()
return transcript_id
@@ -253,7 +252,7 @@ class TranscriptStore:
return count > 0
async def _get_next_seq(self, conversation_id: str) -> int:
"""Get the next sequence number for a conversation."""
"""Fallback next sequence number for stores that cannot expose autoincrement IDs."""
async with self._session_factory() as session:
result = await session.execute(
sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))