mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-10 07:46:02 +00:00
fix(agent-runner): harden state and event APIs
This commit is contained in:
@@ -612,6 +612,7 @@ Runner 失败使用 `run.failed`:
|
||||
|
||||
- Host 在 `ctx.runtime.deadline_at` 下发总 deadline;SDK proxy 必须用该 deadline 限制单次 action timeout。
|
||||
- Host 可以取消 active run;Runtime 应尽力中断 runner。
|
||||
- Protocol v1 的 run 绑定当前 Host 进程和当前 runtime channel,不保证跨 Host 重启恢复。Host 重启、runtime channel 断开或 run session 丢失时,Runtime / remote daemon 必须 fail-fast 并尽力取消仍在执行的 runner,不得继续使用旧 `run_id` 调用 Host API。
|
||||
- Runner 支持中断时应返回或触发 `run.failed`,code 为 `cancelled`。
|
||||
- Host 必须 unregister active run session。
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@ from datetime import datetime
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy import select, delete, update
|
||||
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .descriptor import AgentRunnerDescriptor
|
||||
from .host_models import AgentEventEnvelope, AgentBinding
|
||||
@@ -87,6 +90,49 @@ class PersistentStateStore:
|
||||
|
||||
return json_str, None
|
||||
|
||||
async def _upsert_state_row(
|
||||
self,
|
||||
conn: typing.Any,
|
||||
values: dict[str, typing.Any],
|
||||
) -> None:
|
||||
"""Insert or update a state row by the logical scope/key identity."""
|
||||
update_values = {
|
||||
'value_json': values['value_json'],
|
||||
'updated_at': values['updated_at'],
|
||||
}
|
||||
constraint_columns = ['scope_key', 'state_key']
|
||||
dialect_name = self._db_engine.dialect.name
|
||||
|
||||
if dialect_name == 'sqlite':
|
||||
stmt = sqlite_insert(AgentRunnerState).values(**values)
|
||||
await conn.execute(
|
||||
stmt.on_conflict_do_update(
|
||||
index_elements=constraint_columns,
|
||||
set_=update_values,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if dialect_name == 'postgresql':
|
||||
stmt = postgresql_insert(AgentRunnerState).values(**values)
|
||||
await conn.execute(
|
||||
stmt.on_conflict_do_update(
|
||||
index_elements=constraint_columns,
|
||||
set_=update_values,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await conn.execute(sqlalchemy.insert(AgentRunnerState).values(**values))
|
||||
except IntegrityError:
|
||||
await conn.execute(
|
||||
update(AgentRunnerState)
|
||||
.where(AgentRunnerState.scope_key == values['scope_key'])
|
||||
.where(AgentRunnerState.state_key == values['state_key'])
|
||||
.values(**update_values)
|
||||
)
|
||||
|
||||
# ========== Async DB Operations ==========
|
||||
|
||||
async def build_snapshot_from_event(
|
||||
@@ -195,49 +241,29 @@ class PersistentStateStore:
|
||||
# Build context fields
|
||||
binding_identity = get_binding_identity(binding)
|
||||
|
||||
now = datetime.utcnow()
|
||||
async with self._db_engine.begin() as conn:
|
||||
# Check if entry exists
|
||||
result = await conn.execute(
|
||||
select(AgentRunnerState.id)
|
||||
.where(AgentRunnerState.scope_key == scope_key)
|
||||
.where(AgentRunnerState.state_key == key)
|
||||
await self._upsert_state_row(
|
||||
conn,
|
||||
{
|
||||
'runner_id': descriptor.id,
|
||||
'binding_identity': binding_identity,
|
||||
'scope': scope,
|
||||
'scope_key': scope_key,
|
||||
'state_key': key,
|
||||
'value_json': value_json,
|
||||
'bot_id': event.bot_id,
|
||||
'workspace_id': event.workspace_id,
|
||||
'conversation_id': event.conversation_id,
|
||||
'thread_id': event.thread_id,
|
||||
'actor_type': event.actor.actor_type if event.actor else None,
|
||||
'actor_id': event.actor.actor_id if event.actor else None,
|
||||
'subject_type': event.subject.subject_type if event.subject else None,
|
||||
'subject_id': event.subject.subject_id if event.subject else None,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
},
|
||||
)
|
||||
existing = result.first()
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
if existing:
|
||||
# Update existing entry
|
||||
await conn.execute(
|
||||
update(AgentRunnerState)
|
||||
.where(AgentRunnerState.id == existing.id)
|
||||
.values(
|
||||
value_json=value_json,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Insert new entry
|
||||
await conn.execute(
|
||||
sqlalchemy.insert(AgentRunnerState).values(
|
||||
runner_id=descriptor.id,
|
||||
binding_identity=binding_identity,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
state_key=key,
|
||||
value_json=value_json,
|
||||
bot_id=event.bot_id,
|
||||
workspace_id=event.workspace_id,
|
||||
conversation_id=event.conversation_id,
|
||||
thread_id=event.thread_id,
|
||||
actor_type=event.actor.actor_type if event.actor else None,
|
||||
actor_id=event.actor.actor_id if event.actor else None,
|
||||
subject_type=event.subject.subject_type if event.subject else None,
|
||||
subject_id=event.subject.subject_id if event.subject else None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@@ -293,49 +319,29 @@ class PersistentStateStore:
|
||||
|
||||
context = context or {}
|
||||
|
||||
now = datetime.utcnow()
|
||||
async with self._db_engine.begin() as conn:
|
||||
# Check if entry exists
|
||||
result = await conn.execute(
|
||||
select(AgentRunnerState.id)
|
||||
.where(AgentRunnerState.scope_key == scope_key)
|
||||
.where(AgentRunnerState.state_key == state_key)
|
||||
await self._upsert_state_row(
|
||||
conn,
|
||||
{
|
||||
'runner_id': runner_id,
|
||||
'binding_identity': binding_identity,
|
||||
'scope': scope,
|
||||
'scope_key': scope_key,
|
||||
'state_key': state_key,
|
||||
'value_json': value_json,
|
||||
'bot_id': context.get('bot_id'),
|
||||
'workspace_id': context.get('workspace_id'),
|
||||
'conversation_id': context.get('conversation_id'),
|
||||
'thread_id': context.get('thread_id'),
|
||||
'actor_type': context.get('actor_type'),
|
||||
'actor_id': context.get('actor_id'),
|
||||
'subject_type': context.get('subject_type'),
|
||||
'subject_id': context.get('subject_id'),
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
},
|
||||
)
|
||||
existing = result.first()
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
if existing:
|
||||
# Update existing entry
|
||||
await conn.execute(
|
||||
update(AgentRunnerState)
|
||||
.where(AgentRunnerState.id == existing.id)
|
||||
.values(
|
||||
value_json=value_json,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Insert new entry
|
||||
await conn.execute(
|
||||
sqlalchemy.insert(AgentRunnerState).values(
|
||||
runner_id=runner_id,
|
||||
binding_identity=binding_identity,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
state_key=state_key,
|
||||
value_json=value_json,
|
||||
bot_id=context.get('bot_id'),
|
||||
workspace_id=context.get('workspace_id'),
|
||||
conversation_id=context.get('conversation_id'),
|
||||
thread_id=context.get('thread_id'),
|
||||
actor_type=context.get('actor_type'),
|
||||
actor_id=context.get('actor_id'),
|
||||
subject_type=context.get('subject_type'),
|
||||
subject_id=context.get('subject_id'),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
@@ -243,6 +243,33 @@ def _resolve_run_conversation(
|
||||
return session_conversation_id, None
|
||||
|
||||
|
||||
def _project_event_record_for_api(event: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Project EventLogStore rows onto the SDK AgentEventRecord DTO."""
|
||||
seq = event.get('seq') or event.get('id')
|
||||
return {
|
||||
'event_id': event.get('event_id'),
|
||||
'event_type': event.get('event_type'),
|
||||
'event_time': event.get('event_time'),
|
||||
'source': event.get('source'),
|
||||
'bot_id': event.get('bot_id'),
|
||||
'workspace_id': event.get('workspace_id'),
|
||||
'conversation_id': event.get('conversation_id'),
|
||||
'thread_id': event.get('thread_id'),
|
||||
'actor_type': event.get('actor_type'),
|
||||
'actor_id': event.get('actor_id'),
|
||||
'actor_name': event.get('actor_name'),
|
||||
'subject_type': event.get('subject_type'),
|
||||
'subject_id': event.get('subject_id'),
|
||||
'input_summary': event.get('input_summary'),
|
||||
'input_ref': event.get('input_ref'),
|
||||
'raw_ref': event.get('raw_ref'),
|
||||
'seq': seq,
|
||||
'cursor': event.get('cursor') or (str(seq) if seq is not None else None),
|
||||
'created_at': event.get('created_at'),
|
||||
'metadata': event.get('metadata') or {},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_uuid_list(values: Any) -> list[str]:
|
||||
"""Normalize a user/config supplied UUID list while preserving order."""
|
||||
if not isinstance(values, list):
|
||||
@@ -1619,13 +1646,13 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
session_conversation_id = _get_run_authorization(session).get('conversation_id')
|
||||
event_run_id = event.get('run_id')
|
||||
if event_run_id and event_run_id == run_id:
|
||||
return handler.ActionResponse.success(data=event)
|
||||
return handler.ActionResponse.success(data=_project_event_record_for_api(event))
|
||||
if not session_conversation_id or event.get('conversation_id') != session_conversation_id:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Event {event_id} is not accessible by this run'
|
||||
)
|
||||
|
||||
return handler.ActionResponse.success(data=event)
|
||||
return handler.ActionResponse.success(data=_project_event_record_for_api(event))
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'EVENT_GET error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'Event get error: {e}')
|
||||
@@ -1689,7 +1716,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
)
|
||||
|
||||
return handler.ActionResponse.success(data={
|
||||
'items': items,
|
||||
'items': [_project_event_record_for_api(item) for item in items],
|
||||
'next_cursor': str(next_seq) if next_seq else None,
|
||||
'prev_cursor': None,
|
||||
'has_more': has_more,
|
||||
|
||||
@@ -6,8 +6,15 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from langbot.pkg.agent.runner.event_log_store import EventLogStore
|
||||
from langbot.pkg.agent.runner.session_registry import AgentRunSessionRegistry
|
||||
from langbot.pkg.entity.persistence import event_log as event_log_model
|
||||
from langbot.pkg.entity.persistence.base import Base
|
||||
from langbot.pkg.plugin.handler import RuntimeConnectionHandler
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.page_results import (
|
||||
AgentEventRecord,
|
||||
EventPage,
|
||||
)
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction
|
||||
|
||||
from .conftest import make_resources
|
||||
@@ -37,6 +44,9 @@ def session_registry(monkeypatch):
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine('sqlite+aiosqlite:///:memory:')
|
||||
assert event_log_model.EventLog.__tablename__ == 'event_log'
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
@@ -144,3 +154,69 @@ async def test_event_page_rejects_cross_conversation(session_registry, db_engine
|
||||
|
||||
assert result.code != 0
|
||||
assert 'not accessible' in result.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_get_returns_sdk_record_projection(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'events': ['get']})
|
||||
store = EventLogStore(db_engine)
|
||||
event_id = await store.append_event(
|
||||
event_id='evt_projection_1',
|
||||
event_type='message.received',
|
||||
source='platform',
|
||||
conversation_id='conv_1',
|
||||
actor_type='user',
|
||||
actor_id='user_1',
|
||||
input_summary='hello',
|
||||
input_json={'internal': 'not part of AgentEventRecord'},
|
||||
run_id='run_1',
|
||||
runner_id='plugin:test/runner/default',
|
||||
)
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_get = handler.actions[PluginToRuntimeAction.EVENT_GET.value]
|
||||
|
||||
result = await event_get({
|
||||
'run_id': 'run_1',
|
||||
'event_id': event_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code == 0
|
||||
AgentEventRecord.model_validate(result.data)
|
||||
assert 'id' not in result.data
|
||||
assert 'input_json' not in result.data
|
||||
assert 'run_id' not in result.data
|
||||
assert 'runner_id' not in result.data
|
||||
assert result.data['seq'] is not None
|
||||
assert result.data['cursor'] == str(result.data['seq'])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_page_returns_sdk_page_projection(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'events': ['page']})
|
||||
store = EventLogStore(db_engine)
|
||||
await store.append_event(
|
||||
event_id='evt_projection_page_1',
|
||||
event_type='message.received',
|
||||
source='platform',
|
||||
conversation_id='conv_1',
|
||||
input_json={'internal': 'not part of AgentEventRecord'},
|
||||
run_id='run_other',
|
||||
runner_id='plugin:test/runner/default',
|
||||
)
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||
|
||||
result = await event_page({
|
||||
'run_id': 'run_1',
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
|
||||
assert result.code == 0
|
||||
page = EventPage.model_validate(result.data)
|
||||
assert len(page.items) == 1
|
||||
item = result.data['items'][0]
|
||||
assert 'id' not in item
|
||||
assert 'input_json' not in item
|
||||
assert 'run_id' not in item
|
||||
assert 'runner_id' not in item
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for persistent AgentRunner state store."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
@@ -212,6 +213,26 @@ class TestPersistentStateStore:
|
||||
snapshot = await persistent_store.build_snapshot_from_event(event, binding, descriptor)
|
||||
assert snapshot['conversation']['test_key'] == {'nested': 'value'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_first_state_set_uses_upsert(self, persistent_store):
|
||||
scope_key = 'conversation:runner:binding:conv_concurrent'
|
||||
|
||||
async def set_value(value: int):
|
||||
return await persistent_store.state_set(
|
||||
scope_key=scope_key,
|
||||
state_key='external.concurrent',
|
||||
value={'value': value},
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
binding_identity='binding_001',
|
||||
scope='conversation',
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*(set_value(value) for value in range(8)))
|
||||
|
||||
assert all(success is True and error is None for success, error in results)
|
||||
stored = await persistent_store.state_get(scope_key, 'external.concurrent')
|
||||
assert stored in [{'value': value} for value in range(8)]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_api_methods_normalize_public_key_aliases(self, persistent_store):
|
||||
scope_key = 'conversation:runner:binding:conv_001'
|
||||
|
||||
Reference in New Issue
Block a user