mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-13 17:26:04 +00:00
fix(agent-runner): harden run lifecycle and protocol stores
This commit is contained in:
@@ -324,6 +324,7 @@ class InlineContextPolicy(BaseModel):
|
||||
reason: str | None = None
|
||||
|
||||
class ContextAPICapabilities(BaseModel):
|
||||
prompt_get: bool = False
|
||||
history_page: bool = False
|
||||
history_search: bool = False
|
||||
event_get: bool = False
|
||||
@@ -473,8 +474,8 @@ Host 必须校验 `state.updated` 的 scope、key、value 大小和 JSON 可序
|
||||
|
||||
```python
|
||||
# Model
|
||||
await api.invoke_llm(model_id, messages, funcs=None, extra_args=None)
|
||||
async for chunk in api.invoke_llm_stream(model_id, messages, funcs=None, extra_args=None):
|
||||
await api.invoke_llm(llm_model_uuid, messages, funcs=None, extra_args=None)
|
||||
async for chunk in api.invoke_llm_stream(llm_model_uuid, messages, funcs=None, extra_args=None):
|
||||
...
|
||||
await api.invoke_rerank(rerank_model_id, query, documents, top_k=None)
|
||||
|
||||
@@ -486,13 +487,14 @@ await api.call_tool(tool_name, parameters)
|
||||
await api.retrieve_knowledge(kb_id, query_text, top_k=5, filters=None)
|
||||
|
||||
# History(返回 Transcript projection,不返回原始平台 payload)
|
||||
await api.get_prompt()
|
||||
await api.history_page(conversation_id=None, before_cursor=None, after_cursor=None,
|
||||
limit=50, direction="backward", include_artifacts=False)
|
||||
await api.history_search(query, filters=None, top_k=10)
|
||||
|
||||
# Event(返回稳定 event envelope 或受限 raw ref,不默认返回大 payload)
|
||||
await api.event_get(event_id)
|
||||
await api.event_page(before_cursor=None, limit=50)
|
||||
await api.event_page(conversation_id=None, event_types=None, before_cursor=None, limit=50)
|
||||
await api.steering_pull(mode="all", limit=None)
|
||||
|
||||
# Artifact(必须支持大小限制、MIME 校验、过期时间和授权范围)
|
||||
@@ -502,7 +504,7 @@ await api.artifact_read_range(artifact_id, offset=0, length=65536)
|
||||
|
||||
# State / Storage
|
||||
await api.state_get(scope, key); await api.state_set(scope, key, value); await api.state_delete(scope, key)
|
||||
await api.state_list(scope, prefix=None)
|
||||
await api.state_list(scope, prefix=None, limit=100)
|
||||
await api.get_plugin_storage(key); await api.set_plugin_storage(key, value); await api.delete_plugin_storage(key)
|
||||
await api.get_plugin_storage_keys()
|
||||
await api.get_workspace_storage(key); await api.set_workspace_storage(key, value); await api.delete_workspace_storage(key)
|
||||
@@ -513,6 +515,15 @@ await api.get_file(file_key)
|
||||
await api.get_langbot_version()
|
||||
```
|
||||
|
||||
`invoke_llm()` / `invoke_llm_stream()` 的第一个参数在 SDK 中命名为
|
||||
`llm_model_uuid`,wire payload 字段也是 `llm_model_uuid`。该值对 runner
|
||||
仍是 opaque identifier,不应解析其内部格式。
|
||||
|
||||
`get_prompt()` 返回当前 query-backed run 的 Host effective prompt messages:
|
||||
`list[Message]` 的 JSON 形式。该能力只在 `ctx.context.available_apis.prompt_get`
|
||||
为 true 时可用;没有 query 缓存、prompt 已过期或非 query entry run 时 Host
|
||||
可以返回错误或空列表。Runner 应在不可用时回退到自己的 config/prompt 策略。
|
||||
|
||||
`steering_pull(mode="all")` 是推荐默认:Host 按 claim 顺序返回全部 pending steering 输入并清空对应队列。`mode="one-at-a-time"` 仅用于 runner 主动节流,每次返回一条。Host 不合并多条用户消息;runner 负责在 turn 边界决定模型侧格式。
|
||||
|
||||
Steering 审计使用 EventLog 而不是 Transcript schema 扩展:被 active run 吸收的原始 `message.received` 事件保留原事件类型,并在 `metadata.steering` 标记 `status="queued"`、`trigger_behavior="absorbed_into_active_run"`、`claimed_by_run_id`、`claimed_runner_id`、`claimed_at`。Runner 成功 pull 后,Host 追加 `steering.injected` EventLog 记录,`metadata.steering.status="injected"` 并引用 `source_event_id`。若 run 结束时仍有已 claim 但未 pull 的 steering 输入,Host 追加 `steering.dropped` EventLog 记录,`metadata.steering.status="dropped"` 并引用 `source_event_id`;这不是用户消息事实的删除,只是 dispatch 终态。Transcript 继续只表示会话事实,不承担 dispatch 行为标记。
|
||||
|
||||
@@ -8,6 +8,7 @@ import uuid
|
||||
import base64
|
||||
import os
|
||||
|
||||
import aiofiles
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
@@ -17,6 +18,23 @@ from ...entity.persistence.bstorage import BinaryStorage
|
||||
|
||||
_FILE_ARTIFACT_METADATA_KEY = '_langbot_file_artifact'
|
||||
_ARTIFACT_THREAD_METADATA_KEY = '_langbot_thread_id'
|
||||
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
def _utc_now() -> datetime.datetime:
|
||||
return datetime.datetime.now(UTC)
|
||||
|
||||
|
||||
def _as_utc(value: datetime.datetime) -> datetime.datetime:
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=UTC)
|
||||
return value.astimezone(UTC)
|
||||
|
||||
|
||||
def _datetime_to_epoch(value: datetime.datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(_as_utc(value).timestamp())
|
||||
|
||||
|
||||
class ArtifactStore:
|
||||
@@ -191,7 +209,7 @@ class ArtifactStore:
|
||||
runner_id=runner_id,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
created_at=_utc_now(),
|
||||
expires_at=expires_at,
|
||||
metadata_json=json.dumps(metadata_payload) if metadata_payload else None,
|
||||
)
|
||||
@@ -336,7 +354,7 @@ class ArtifactStore:
|
||||
}
|
||||
|
||||
if storage_type == 'file':
|
||||
return self._read_file_storage(record, artifact_id, offset, limit)
|
||||
return await self._read_file_storage(record, artifact_id, offset, limit)
|
||||
|
||||
# For other storage types, return storage reference
|
||||
# (caller can use file_key for chunked transfer)
|
||||
@@ -363,7 +381,7 @@ class ArtifactStore:
|
||||
backing lifecycle remains owned by the storage provider.
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.datetime.utcnow()
|
||||
now = _utc_now()
|
||||
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
@@ -416,7 +434,7 @@ class ArtifactStore:
|
||||
return None
|
||||
return row.value
|
||||
|
||||
def _read_file_storage(
|
||||
async def _read_file_storage(
|
||||
self,
|
||||
record: AgentArtifact,
|
||||
artifact_id: str,
|
||||
@@ -441,9 +459,9 @@ class ArtifactStore:
|
||||
if offset >= file_size:
|
||||
content = b''
|
||||
else:
|
||||
with open(real_path, 'rb') as f:
|
||||
f.seek(offset)
|
||||
content = f.read(limit)
|
||||
async with aiofiles.open(real_path, 'rb') as f:
|
||||
await f.seek(offset)
|
||||
content = await f.read(limit)
|
||||
|
||||
return {
|
||||
'artifact_id': artifact_id,
|
||||
@@ -491,8 +509,8 @@ class ArtifactStore:
|
||||
if row.expires_at is None:
|
||||
return False
|
||||
if now is None:
|
||||
now = datetime.datetime.utcnow()
|
||||
return row.expires_at <= now
|
||||
now = _utc_now()
|
||||
return _as_utc(row.expires_at) <= _as_utc(now)
|
||||
|
||||
def _row_to_public_dict(self, row: AgentArtifact) -> dict[str, typing.Any]:
|
||||
"""Convert an AgentArtifact row to public dict.
|
||||
@@ -511,7 +529,7 @@ class ArtifactStore:
|
||||
'conversation_id': row.conversation_id,
|
||||
'run_id': row.run_id,
|
||||
'runner_id': row.runner_id,
|
||||
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
|
||||
'expires_at': int(row.expires_at.timestamp()) if row.expires_at else None,
|
||||
'created_at': _datetime_to_epoch(row.created_at),
|
||||
'expires_at': _datetime_to_epoch(row.expires_at),
|
||||
'metadata': self._public_metadata(row.metadata_json),
|
||||
}
|
||||
|
||||
@@ -22,7 +22,14 @@ class AgentBindingResolver:
|
||||
event: AgentEventEnvelope,
|
||||
agents: list[AgentConfig],
|
||||
) -> AgentBinding:
|
||||
"""Resolve exactly one enabled Agent for the event."""
|
||||
"""Resolve exactly one enabled Agent for the event.
|
||||
|
||||
Callers that source agents from bot/workspace/global configuration must
|
||||
pre-filter candidates to the event scope before calling this resolver.
|
||||
The current AgentConfig model represents one already-selected product
|
||||
Agent and does not carry enough scope metadata to make that decision
|
||||
safely here.
|
||||
"""
|
||||
matches = [
|
||||
agent
|
||||
for agent in agents
|
||||
|
||||
@@ -13,6 +13,23 @@ from sqlalchemy.orm import sessionmaker
|
||||
from ...entity.persistence.event_log import EventLog
|
||||
|
||||
|
||||
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
def _utc_now() -> datetime.datetime:
|
||||
return datetime.datetime.now(UTC)
|
||||
|
||||
|
||||
def _datetime_to_epoch(value: datetime.datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
else:
|
||||
value = value.astimezone(UTC)
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class EventLogStore:
|
||||
"""Store for EventLog records.
|
||||
|
||||
@@ -107,7 +124,7 @@ class EventLogStore:
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
created_at=_utc_now(),
|
||||
)
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
@@ -277,7 +294,7 @@ class EventLogStore:
|
||||
'id': row.id,
|
||||
'event_id': row.event_id,
|
||||
'event_type': row.event_type,
|
||||
'event_time': int(row.event_time.timestamp()) if row.event_time else None,
|
||||
'event_time': _datetime_to_epoch(row.event_time),
|
||||
'source': row.source,
|
||||
'bot_id': row.bot_id,
|
||||
'workspace_id': row.workspace_id,
|
||||
@@ -293,6 +310,6 @@ class EventLogStore:
|
||||
'raw_ref': row.raw_ref,
|
||||
'run_id': row.run_id,
|
||||
'runner_id': row.runner_id,
|
||||
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
|
||||
'created_at': _datetime_to_epoch(row.created_at),
|
||||
'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
|
||||
}
|
||||
|
||||
@@ -103,40 +103,6 @@ class AgentRunOrchestrator:
|
||||
|
||||
state_context = build_state_context(event, binding, descriptor)
|
||||
run_id = context['run_id']
|
||||
await self._session_registry.register(
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
query_id=session_query_id,
|
||||
plugin_identity=descriptor.get_plugin_id(),
|
||||
resources=resources,
|
||||
available_apis=context.get('context', {}).get('available_apis'),
|
||||
conversation_id=event.conversation_id,
|
||||
bot_id=event.bot_id,
|
||||
workspace_id=event.workspace_id,
|
||||
thread_id=event.thread_id,
|
||||
state_policy={
|
||||
'enable_state': binding.state_policy.enable_state,
|
||||
'state_scopes': list(binding.state_policy.state_scopes),
|
||||
},
|
||||
state_context=state_context,
|
||||
)
|
||||
|
||||
event_log_id = await self.journal.write_event_log(
|
||||
event=event,
|
||||
binding=binding,
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
)
|
||||
await self.journal.register_input_artifacts(
|
||||
event=event,
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
)
|
||||
if event.event_type == 'message.received' and event.conversation_id:
|
||||
await self.journal.write_user_transcript(
|
||||
event=event,
|
||||
event_log_id=event_log_id,
|
||||
)
|
||||
|
||||
pending_artifact_refs: list[dict[str, typing.Any]] = []
|
||||
seen_sequences: set[int] = set()
|
||||
@@ -144,6 +110,41 @@ class AgentRunOrchestrator:
|
||||
assistant_transcript_written = False
|
||||
|
||||
try:
|
||||
await self._session_registry.register(
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
query_id=session_query_id,
|
||||
plugin_identity=descriptor.get_plugin_id(),
|
||||
resources=resources,
|
||||
available_apis=context.get('context', {}).get('available_apis'),
|
||||
conversation_id=event.conversation_id,
|
||||
bot_id=event.bot_id,
|
||||
workspace_id=event.workspace_id,
|
||||
thread_id=event.thread_id,
|
||||
state_policy={
|
||||
'enable_state': binding.state_policy.enable_state,
|
||||
'state_scopes': list(binding.state_policy.state_scopes),
|
||||
},
|
||||
state_context=state_context,
|
||||
)
|
||||
|
||||
event_log_id = await self.journal.write_event_log(
|
||||
event=event,
|
||||
binding=binding,
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
)
|
||||
await self.journal.register_input_artifacts(
|
||||
event=event,
|
||||
run_id=run_id,
|
||||
runner_id=descriptor.id,
|
||||
)
|
||||
if event.event_type == 'message.received' and event.conversation_id:
|
||||
await self.journal.write_user_transcript(
|
||||
event=event,
|
||||
event_log_id=event_log_id,
|
||||
)
|
||||
|
||||
async for result_dict in self.invoker.invoke(descriptor, context):
|
||||
sequence = result_dict.get('sequence')
|
||||
if sequence is not None:
|
||||
|
||||
@@ -27,6 +27,7 @@ from .host_models import (
|
||||
StatePolicy,
|
||||
DeliveryPolicy,
|
||||
)
|
||||
from .config_migration import ConfigMigration
|
||||
from . import events as runner_events
|
||||
|
||||
|
||||
@@ -42,6 +43,7 @@ class QueryEntryAdapter:
|
||||
INTERNAL_PREFIX = '_'
|
||||
SENSITIVE_PATTERNS = ('secret', 'token', 'key', 'password', 'credential', 'api_key', 'apikey')
|
||||
PERMISSION_VARS = ('_pipeline_bound_plugins', '_authorized', '_permission')
|
||||
EVENT_DATA_MAX_STRING_BYTES = 512
|
||||
|
||||
@classmethod
|
||||
def query_to_event(
|
||||
@@ -103,8 +105,7 @@ class QueryEntryAdapter:
|
||||
) -> AgentConfig:
|
||||
"""Project the current Pipeline config container into target Agent config."""
|
||||
pipeline_config = query.pipeline_config or {}
|
||||
ai_config = pipeline_config.get('ai', {})
|
||||
runner_config = ai_config.get('runner_config', {}).get(runner_id, {})
|
||||
runner_config = ConfigMigration.resolve_runner_config(pipeline_config, runner_id)
|
||||
agent_id = getattr(query, 'pipeline_uuid', None)
|
||||
|
||||
# Build resource policy from current config
|
||||
@@ -199,12 +200,13 @@ class QueryEntryAdapter:
|
||||
event_data: dict[str, typing.Any] = {}
|
||||
if message_event and hasattr(message_event, 'model_dump'):
|
||||
try:
|
||||
event_data = message_event.model_dump(mode='json')
|
||||
raw_event_data = message_event.model_dump(mode='json')
|
||||
except TypeError:
|
||||
event_data = message_event.model_dump()
|
||||
raw_event_data = message_event.model_dump()
|
||||
except Exception:
|
||||
event_data = {}
|
||||
event_data.pop('source_platform_object', None)
|
||||
raw_event_data = {}
|
||||
if isinstance(raw_event_data, dict):
|
||||
event_data = cls._compact_event_data(raw_event_data)
|
||||
|
||||
source_event_type = None
|
||||
if message_event:
|
||||
@@ -231,6 +233,25 @@ class QueryEntryAdapter:
|
||||
data=event_data,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _compact_event_data(
|
||||
cls,
|
||||
event_data: dict[str, typing.Any],
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Keep only small scalar source-event metadata in event.data."""
|
||||
compact: dict[str, typing.Any] = {}
|
||||
for key, value in event_data.items():
|
||||
if key == 'source_platform_object' or key.startswith('_'):
|
||||
continue
|
||||
if value is None or isinstance(value, (bool, int, float)):
|
||||
compact[key] = value
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
if len(value.encode('utf-8')) <= cls.EVENT_DATA_MAX_STRING_BYTES:
|
||||
compact[key] = value
|
||||
continue
|
||||
return compact
|
||||
|
||||
@classmethod
|
||||
def _build_scoped_event_id(
|
||||
cls,
|
||||
@@ -430,6 +451,18 @@ class QueryEntryAdapter:
|
||||
import uuid
|
||||
|
||||
attachments: list[dict[str, typing.Any]] = []
|
||||
seen_keys: dict[tuple[str, str, str], set[str]] = {}
|
||||
|
||||
def add_attachment(attachment: dict[str, typing.Any]) -> None:
|
||||
key = cls._attachment_dedupe_key(attachment)
|
||||
if key is not None:
|
||||
source = str(attachment.get('source') or '')
|
||||
sources = seen_keys.setdefault(key, set())
|
||||
if source and sources and source not in sources:
|
||||
return
|
||||
if source:
|
||||
sources.add(source)
|
||||
attachments.append(attachment)
|
||||
|
||||
for elem in contents:
|
||||
elem_type = elem.get('type')
|
||||
@@ -437,21 +470,21 @@ class QueryEntryAdapter:
|
||||
|
||||
if elem_type == 'image_url':
|
||||
image_url = elem.get('image_url') or {}
|
||||
attachments.append({
|
||||
add_attachment({
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'image',
|
||||
'source': 'url',
|
||||
'url': image_url.get('url') if isinstance(image_url, dict) else str(image_url),
|
||||
})
|
||||
elif elem_type == 'image_base64':
|
||||
attachments.append({
|
||||
add_attachment({
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'image',
|
||||
'source': 'base64',
|
||||
'content': elem.get('image_base64'),
|
||||
})
|
||||
elif elem_type == 'file_url':
|
||||
attachments.append({
|
||||
add_attachment({
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'file',
|
||||
'source': 'url',
|
||||
@@ -459,7 +492,7 @@ class QueryEntryAdapter:
|
||||
'name': elem.get('file_name'),
|
||||
})
|
||||
elif elem_type == 'file_base64':
|
||||
attachments.append({
|
||||
add_attachment({
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'file',
|
||||
'source': 'base64',
|
||||
@@ -478,32 +511,56 @@ class QueryEntryAdapter:
|
||||
artifact_id = str(uuid.uuid4()) # Generate unique ID
|
||||
|
||||
if isinstance(component, platform_message.Image):
|
||||
attachments.append({
|
||||
image_id = component.image_id or None
|
||||
image_url = component.url or None
|
||||
image_base64 = component.base64 or None
|
||||
add_attachment({
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'image',
|
||||
'source': 'message_chain',
|
||||
'id': component.image_id or None,
|
||||
'url': component.url or None,
|
||||
'id': image_id,
|
||||
'url': image_url,
|
||||
'content': image_base64,
|
||||
})
|
||||
elif isinstance(component, platform_message.File):
|
||||
attachments.append({
|
||||
add_attachment({
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'file',
|
||||
'source': 'message_chain',
|
||||
'id': component.id or None,
|
||||
'name': component.name or None,
|
||||
'url': component.url or None,
|
||||
'content': component.base64 or None,
|
||||
})
|
||||
elif isinstance(component, platform_message.Voice):
|
||||
attachments.append({
|
||||
add_attachment({
|
||||
'artifact_id': artifact_id,
|
||||
'artifact_type': 'voice',
|
||||
'source': 'message_chain',
|
||||
'id': component.voice_id or None,
|
||||
'url': component.url or None,
|
||||
'content': component.base64 or None,
|
||||
})
|
||||
|
||||
return attachments
|
||||
|
||||
@classmethod
|
||||
def _attachment_dedupe_key(
|
||||
cls,
|
||||
attachment: dict[str, typing.Any],
|
||||
) -> tuple[str, str, str] | None:
|
||||
"""Return a stable key for the same attachment across content sources."""
|
||||
artifact_type = attachment.get('artifact_type')
|
||||
if not artifact_type:
|
||||
return None
|
||||
for field in ('id', 'url', 'content'):
|
||||
value = attachment.get(field)
|
||||
if value:
|
||||
if field == 'content':
|
||||
value = hashlib.sha256(str(value).encode('utf-8')).hexdigest()
|
||||
return str(artifact_type), field, str(value)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _build_delivery_context(
|
||||
cls,
|
||||
|
||||
@@ -165,7 +165,11 @@ class AgentRunJournal:
|
||||
input_json=input_json,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
event_time=datetime.datetime.fromtimestamp(event.event_time) if event.event_time else None,
|
||||
event_time=(
|
||||
datetime.datetime.fromtimestamp(event.event_time, datetime.timezone.utc)
|
||||
if event.event_time
|
||||
else None
|
||||
),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@@ -468,7 +472,10 @@ class AgentRunJournal:
|
||||
raw_event_time = event.get('event_time')
|
||||
if raw_event_time:
|
||||
try:
|
||||
event_time = datetime.datetime.fromtimestamp(raw_event_time)
|
||||
event_time = datetime.datetime.fromtimestamp(
|
||||
raw_event_time,
|
||||
datetime.timezone.utc,
|
||||
)
|
||||
except (TypeError, ValueError, OSError):
|
||||
event_time = None
|
||||
|
||||
|
||||
@@ -122,6 +122,9 @@ class AgentRunSessionRegistry:
|
||||
state_policy: State policy from binding (enable_state, state_scopes)
|
||||
state_context: Context for state API (scope_keys, binding_identity, etc.)
|
||||
"""
|
||||
if not isinstance(plugin_identity, str) or not plugin_identity.strip():
|
||||
raise ValueError('plugin_identity is required for agent run sessions')
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
available_apis = copy.deepcopy(available_apis or {})
|
||||
|
||||
@@ -14,6 +14,23 @@ from ...entity.persistence.transcript import Transcript
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
|
||||
|
||||
UTC = datetime.timezone.utc
|
||||
|
||||
|
||||
def _utc_now() -> datetime.datetime:
|
||||
return datetime.datetime.now(UTC)
|
||||
|
||||
|
||||
def _datetime_to_epoch(value: datetime.datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
else:
|
||||
value = value.astimezone(UTC)
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class TranscriptStore:
|
||||
"""Store for Transcript records.
|
||||
|
||||
@@ -94,7 +111,7 @@ class TranscriptStore:
|
||||
seq=0,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
created_at=_utc_now(),
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
session.add(item)
|
||||
@@ -371,7 +388,7 @@ class TranscriptStore:
|
||||
'content_json': json.loads(row.content_json) if row.content_json else None,
|
||||
'seq': row.seq,
|
||||
'cursor': str(row.seq),
|
||||
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
|
||||
'created_at': _datetime_to_epoch(row.created_at),
|
||||
'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class Transcript(Base):
|
||||
"""Artifact references as JSON string (list of ArtifactRef)."""
|
||||
|
||||
# Sequence for cursor-based pagination
|
||||
seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, index=True)
|
||||
seq = sqlalchemy.Column(sqlalchemy.Integer, nullable=False)
|
||||
"""Monotonic cursor sequence for pagination."""
|
||||
|
||||
# Context
|
||||
|
||||
@@ -29,6 +29,26 @@ def _load_config(config_value):
|
||||
return None
|
||||
|
||||
|
||||
def _update_config(conn, table_name: str, pipeline_uuid: str, migrated_config: dict) -> None:
|
||||
"""Write JSON config using a dialect-compatible bind."""
|
||||
config_json = json.dumps(migrated_config)
|
||||
if conn.dialect.name == 'postgresql':
|
||||
conn.execute(
|
||||
sa.text(
|
||||
f'UPDATE {table_name} '
|
||||
'SET config = CAST(:config AS JSON) '
|
||||
'WHERE uuid = :uuid'
|
||||
),
|
||||
{'config': config_json, 'uuid': pipeline_uuid},
|
||||
)
|
||||
return
|
||||
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE {table_name} SET config = :config WHERE uuid = :uuid'),
|
||||
{'config': config_json, 'uuid': pipeline_uuid},
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Normalize existing pipeline config containers."""
|
||||
conn = op.get_bind()
|
||||
@@ -56,10 +76,7 @@ def upgrade() -> None:
|
||||
|
||||
# Only update if config changed
|
||||
if json.dumps(config, sort_keys=True) != json.dumps(migrated_config, sort_keys=True):
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE {table_name} SET config = :config WHERE uuid = :uuid'),
|
||||
{'config': json.dumps(migrated_config), 'uuid': pipeline_uuid},
|
||||
)
|
||||
_update_config(conn, table_name, pipeline_uuid, migrated_config)
|
||||
except Exception:
|
||||
# Skip invalid configs
|
||||
continue
|
||||
|
||||
@@ -211,19 +211,23 @@ async def _validate_agent_run_session(
|
||||
)
|
||||
|
||||
session_plugin_identity = session.get('plugin_identity')
|
||||
if session_plugin_identity:
|
||||
if not caller_plugin_identity:
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'caller_plugin_identity is required for run_id {run_id}'
|
||||
)
|
||||
if caller_plugin_identity != session_plugin_identity:
|
||||
ap.logger.warning(
|
||||
f'{api_name}: caller_plugin_identity {caller_plugin_identity} '
|
||||
f'does not match session plugin_identity {session_plugin_identity}'
|
||||
)
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
||||
)
|
||||
if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip():
|
||||
ap.logger.warning(f'{api_name}: run_id {run_id} has no plugin_identity')
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'Run session {run_id} has no plugin_identity'
|
||||
)
|
||||
if not caller_plugin_identity:
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'caller_plugin_identity is required for run_id {run_id}'
|
||||
)
|
||||
if caller_plugin_identity != session_plugin_identity:
|
||||
ap.logger.warning(
|
||||
f'{api_name}: caller_plugin_identity {caller_plugin_identity} '
|
||||
f'does not match session plugin_identity {session_plugin_identity}'
|
||||
)
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
||||
)
|
||||
|
||||
if api_capability:
|
||||
available_apis = _get_run_authorization(session).get('available_apis', {})
|
||||
@@ -384,19 +388,25 @@ async def _validate_run_authorization(
|
||||
)
|
||||
|
||||
session_plugin_identity = session.get('plugin_identity')
|
||||
if session_plugin_identity:
|
||||
if not caller_plugin_identity:
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'caller_plugin_identity is required for run_id {run_id}',
|
||||
)
|
||||
if caller_plugin_identity != session_plugin_identity:
|
||||
ap.logger.warning(
|
||||
f'{resource_type.upper()}: caller_plugin_identity {caller_plugin_identity} '
|
||||
f'does not match session plugin_identity {session_plugin_identity}'
|
||||
)
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'Plugin identity mismatch: caller {caller_plugin_identity} is not authorized for run_id {run_id}',
|
||||
)
|
||||
if not isinstance(session_plugin_identity, str) or not session_plugin_identity.strip():
|
||||
ap.logger.warning(
|
||||
f'{resource_type.upper()}: run_id {run_id} has no plugin_identity'
|
||||
)
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'Run session {run_id} has no plugin_identity',
|
||||
)
|
||||
if not caller_plugin_identity:
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'caller_plugin_identity is required for run_id {run_id}',
|
||||
)
|
||||
if caller_plugin_identity != session_plugin_identity:
|
||||
ap.logger.warning(
|
||||
f'{resource_type.upper()}: caller_plugin_identity {caller_plugin_identity} '
|
||||
f'does not match session plugin_identity {session_plugin_identity}'
|
||||
)
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'Plugin identity mismatch: caller {caller_plugin_identity} is not authorized for run_id {run_id}',
|
||||
)
|
||||
|
||||
if not session_registry.is_resource_allowed(session, resource_type, resource_id, operation):
|
||||
ap.logger.warning(
|
||||
|
||||
@@ -106,6 +106,28 @@ class TestQueryToEventEnvelope:
|
||||
"message_id": "source-message-1",
|
||||
}
|
||||
|
||||
def test_query_to_event_keeps_large_payloads_out_of_event_data(self, mock_query):
|
||||
"""Large or nested platform payloads should not be duplicated into event.data."""
|
||||
source_event = Mock()
|
||||
source_event.type = "platform.message.created"
|
||||
source_event.time = 1700000000
|
||||
source_event.sender = None
|
||||
source_event.model_dump = Mock(return_value={
|
||||
"type": "platform.message.created",
|
||||
"message_id": "source-message-1",
|
||||
"message_chain": [{"type": "Image", "base64": "data:image/png;base64," + ("a" * 1024)}],
|
||||
"raw_text": "x" * 1024,
|
||||
"source_platform_object": {"large": "payload"},
|
||||
})
|
||||
mock_query.message_event = source_event
|
||||
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
assert event.data == {
|
||||
"type": "platform.message.created",
|
||||
"message_id": "source-message-1",
|
||||
}
|
||||
|
||||
def test_query_to_event_handles_missing_message_chain(self, mock_query):
|
||||
"""Test delivery context building when Query has no message_chain."""
|
||||
delattr(mock_query, "message_chain")
|
||||
@@ -137,6 +159,29 @@ class TestQueryConfigToAgentConfig:
|
||||
|
||||
assert agent_config.runner_id == "plugin:author/plugin/runner"
|
||||
|
||||
def test_config_to_agent_config_uses_legacy_runner_config_migration(self, mock_query):
|
||||
"""Temporary query adapter must share the normal runner config resolver."""
|
||||
mock_query.pipeline_config = {
|
||||
"ai": {
|
||||
"runner": {"runner": "local-agent"},
|
||||
"local-agent": {
|
||||
"model": "model-primary",
|
||||
"knowledge-base": "kb-001",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
agent_config = QueryEntryAdapter.config_to_agent_config(
|
||||
mock_query,
|
||||
"plugin:langbot/local-agent/default",
|
||||
)
|
||||
|
||||
assert agent_config.runner_config["model"] == {
|
||||
"primary": "model-primary",
|
||||
"fallbacks": [],
|
||||
}
|
||||
assert agent_config.runner_config["knowledge-bases"] == ["kb-001"]
|
||||
|
||||
def test_resolver_projects_agent_scope(self, mock_query):
|
||||
"""Test binding scope projection through the resolver."""
|
||||
event = QueryEntryAdapter.query_to_event(mock_query)
|
||||
|
||||
@@ -2024,6 +2024,75 @@ class TestCallerPluginIdentityValidation:
|
||||
|
||||
await registry.unregister('run_no_caller_identity')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_missing_plugin_identity_denied(self):
|
||||
"""Malformed legacy sessions without plugin_identity fail closed."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
registry = get_session_registry()
|
||||
resources = make_resources(models=[{'model_id': 'model_001'}])
|
||||
session = make_session(
|
||||
run_id='run_missing_session_identity',
|
||||
runner_id='plugin:test/runner/default',
|
||||
plugin_identity='',
|
||||
resources=resources,
|
||||
)
|
||||
async with registry._lock:
|
||||
registry._sessions['run_missing_session_identity'] = session
|
||||
|
||||
from langbot.pkg.plugin.handler import _validate_run_authorization
|
||||
|
||||
mock_ap = MagicMock()
|
||||
mock_ap.logger = MagicMock()
|
||||
|
||||
session, error = await _validate_run_authorization(
|
||||
'run_missing_session_identity',
|
||||
'model',
|
||||
'model_001',
|
||||
mock_ap,
|
||||
caller_plugin_identity='test/runner',
|
||||
)
|
||||
|
||||
assert session is None
|
||||
assert error is not None
|
||||
assert 'no plugin_identity' in error.message
|
||||
|
||||
await registry.unregister('run_missing_session_identity')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pull_api_session_missing_plugin_identity_denied(self):
|
||||
"""Pull API validation also fails closed for missing session identity."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
registry = get_session_registry()
|
||||
session = make_session(
|
||||
run_id='run_missing_pull_identity',
|
||||
runner_id='plugin:test/runner/default',
|
||||
plugin_identity='',
|
||||
available_apis={'history_page': True},
|
||||
)
|
||||
async with registry._lock:
|
||||
registry._sessions['run_missing_pull_identity'] = session
|
||||
|
||||
from langbot.pkg.plugin.handler import _validate_agent_run_session
|
||||
|
||||
mock_ap = MagicMock()
|
||||
mock_ap.logger = MagicMock()
|
||||
|
||||
session, error = await _validate_agent_run_session(
|
||||
'run_missing_pull_identity',
|
||||
'test/runner',
|
||||
mock_ap,
|
||||
'HISTORY_PAGE',
|
||||
'history_page',
|
||||
)
|
||||
|
||||
assert session is None
|
||||
assert error is not None
|
||||
assert 'no plugin_identity' in error.message
|
||||
|
||||
await registry.unregister('run_missing_pull_identity')
|
||||
|
||||
|
||||
class TestBackwardCompatStorageNoRunId:
|
||||
"""Tests for unscoped storage actions without run_id.
|
||||
|
||||
@@ -275,6 +275,37 @@ def test_context_builder_includes_consumable_base64_attachments():
|
||||
assert input_data.attachments[1].name == "hello.txt"
|
||||
|
||||
|
||||
def test_context_builder_deduplicates_message_chain_attachments():
|
||||
query = make_query()
|
||||
query.user_message = None
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
[platform_message.Image(base64="data:image/jpeg;base64,aGVsbG8=")]
|
||||
)
|
||||
|
||||
input_data = QueryEntryAdapter._build_input(query)
|
||||
|
||||
assert [content.type for content in input_data.contents] == ["image_base64"]
|
||||
assert len(input_data.attachments) == 1
|
||||
assert input_data.attachments[0].artifact_type == "image"
|
||||
assert input_data.attachments[0].content == "data:image/jpeg;base64,aGVsbG8="
|
||||
|
||||
|
||||
def test_context_builder_preserves_same_source_duplicate_attachments():
|
||||
query = make_query()
|
||||
query.user_message = provider_message.Message(
|
||||
role="user",
|
||||
content=[
|
||||
provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="),
|
||||
provider_message.ContentElement.from_image_base64("data:image/png;base64,aGVsbG8="),
|
||||
],
|
||||
)
|
||||
query.message_chain = platform_message.MessageChain([])
|
||||
|
||||
input_data = QueryEntryAdapter._build_input(query)
|
||||
|
||||
assert [attachment.artifact_type for attachment in input_data.attachments] == ["image", "image"]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def clean_agent_state():
|
||||
"""Reset all singleton stores and create a test database engine."""
|
||||
@@ -546,6 +577,29 @@ async def test_orchestrator_unregisters_session_after_runner_failure(clean_agent
|
||||
assert await get_session_registry().get(context["run_id"]) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_unregisters_session_after_event_log_failure(clean_agent_state):
|
||||
"""Journal failures before runner invocation must not leave steerable sessions."""
|
||||
db_engine = clean_agent_state
|
||||
descriptor = make_descriptor()
|
||||
plugin_connector = FakePluginConnector(
|
||||
results=[
|
||||
{
|
||||
"type": "message.completed",
|
||||
"data": {"message": {"role": "assistant", "content": "unused"}},
|
||||
}
|
||||
]
|
||||
)
|
||||
orchestrator = AgentRunOrchestrator(FakeApplication(plugin_connector, db_engine), FakeRegistry(descriptor))
|
||||
orchestrator.journal.write_event_log = AsyncMock(side_effect=RuntimeError("journal unavailable"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="journal unavailable"):
|
||||
[message async for message in orchestrator.run_from_query(make_query())]
|
||||
|
||||
assert plugin_connector.contexts == []
|
||||
assert await get_session_registry().list_active_runs() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_enforces_total_runner_deadline(clean_agent_state):
|
||||
"""Test that orchestrator enforces total runner timeout."""
|
||||
|
||||
@@ -49,6 +49,20 @@ class TestSessionRegistryBasic:
|
||||
assert 'permissions' not in result
|
||||
assert '_authorized_ids' not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_requires_plugin_identity(self):
|
||||
"""Agent run sessions must always have an owning plugin identity."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
|
||||
with pytest.raises(ValueError, match='plugin_identity is required'):
|
||||
await registry.register(
|
||||
run_id='run_missing_identity',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='',
|
||||
resources=make_resources(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_freezes_authorization_snapshot(self):
|
||||
"""Register should freeze authorization data for the run."""
|
||||
|
||||
Reference in New Issue
Block a user