feat(agent-runner): support run steering

This commit is contained in:
huanghuoguoguo
2026-06-11 23:03:44 +08:00
parent 9a231927ad
commit c3fa4b6a68
9 changed files with 299 additions and 6 deletions

View File

@@ -103,6 +103,7 @@ class AgentRunnerCapabilities(BaseModel):
multimodal_input: bool = False
skill_authoring: bool = False
interrupt: bool = False
steering: bool = False
model_config = ConfigDict(extra="forbid")
```
@@ -113,6 +114,7 @@ class AgentRunnerCapabilities(BaseModel):
- `multimodal_input`: runner 可以处理非纯文本 input / artifact。
- `skill_authoring`: runner 需要 Host 提供 skill facts 以及 skill authoring tools例如 `activate` / `register_skill`
- `interrupt`: runner 支持取消或中断。
- `steering`: runner 支持在 turn 边界通过 Host pull API 消费同 conversation 在途追加消息。
Capabilities 字段全部是 `bool`,未知 key 禁止进入 typed manifest。早期草案里的上下文/会话类 capability 已删除;对应语义由 event-first context 和 runner-owned context 原则表达。
@@ -323,6 +325,7 @@ class ContextAPICapabilities(BaseModel):
artifact_read: bool = False
state: bool = False
storage: bool = False
steering_pull: bool = False
```
`ContextAccess` 告诉 runnerHost inline 了什么、没 inline 什么、需要更多上下文时走哪些 API。它是 runner 按需读取上下文的入口说明,不是 Host 的业务上下文编排策略。
@@ -483,6 +486,7 @@ await api.history_search(query, filters=None, top_k=10)
# Event返回稳定 event envelope 或受限 raw ref不默认返回大 payload
await api.event_get(event_id)
await api.event_page(before_cursor=None, limit=50)
await api.steering_pull(mode="all", limit=None)
# Artifact必须支持大小限制、MIME 校验、过期时间和授权范围)
await api.artifact_metadata(artifact_id)
@@ -563,6 +567,20 @@ class EventPage(BaseModel):
has_more: bool = False
total_count: int | None = None
class SteeringInputItem(BaseModel):
claimed_run_id: str
runner_id: str
claimed_at: int | None = None
event: AgentEventContext
input: AgentInput
conversation: ConversationContext | None = None
actor: ActorContext | None = None
subject: SubjectContext | None = None
metadata: dict[str, Any] = {}
class SteeringPullResult(BaseModel):
items: list[SteeringInputItem] = []
class ArtifactMetadata(BaseModel):
artifact_id: str
artifact_type: str

View File

@@ -371,6 +371,7 @@ class AgentRunContextBuilder:
event_page_enabled = 'page' in event_perms and conversation_id is not None
artifact_metadata_enabled = 'metadata' in artifact_perms
artifact_read_enabled = 'read' in artifact_perms
steering_pull_enabled = bool(getattr(descriptor.capabilities, 'steering', False)) and conversation_id is not None
# Determine state API availability based on binding state_policy.
state_enabled = False
@@ -425,5 +426,6 @@ class AgentRunContextBuilder:
'artifact_read': artifact_read_enabled,
'state': state_enabled,
'storage': storage_enabled,
'steering_pull': steering_pull_enabled,
},
}

View File

@@ -76,3 +76,7 @@ class AgentRunnerDescriptor(pydantic.BaseModel):
def supports_knowledge_retrieval(self) -> bool:
"""Check if runner supports knowledge retrieval."""
return self.capabilities.knowledge_retrieval
def supports_steering(self) -> bool:
"""Check if runner supports run steering/follow-up input."""
return bool(getattr(self.capabilities, 'steering', False))

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import time
import typing
from langbot_plugin.api.entities.builtin.provider import message as provider_message
@@ -201,6 +202,98 @@ class AgentRunOrchestrator:
"""Resolve runner ID for telemetry/logging without full execution."""
return self.query_bridge.resolve_runner_id_for_telemetry(query)
async def try_claim_steering_from_query(
self,
query: pipeline_query.Query,
) -> bool:
"""Claim a query as steering input for an active run when possible."""
plan = self.query_bridge.build_plan(query)
event = plan.event
binding = plan.binding
if event.event_type != 'message.received' or not event.conversation_id:
return False
descriptor = await self.registry.get(binding.runner_id, plan.bound_plugins)
if not descriptor.supports_steering():
return False
target_run_id = await self._session_registry.find_steering_target(
conversation_id=event.conversation_id,
runner_id=descriptor.id,
)
if target_run_id is None:
return False
steering_item = self._build_steering_item(event, target_run_id, descriptor.id)
if not await self._session_registry.enqueue_steering(target_run_id, steering_item):
return False
try:
event_log_id = await self.journal.write_event_log(
event=event,
binding=binding,
run_id=target_run_id,
runner_id=descriptor.id,
)
await self.journal.register_input_artifacts(
event=event,
run_id=target_run_id,
runner_id=descriptor.id,
)
await self.journal.write_user_transcript(event, event_log_id)
except Exception as exc:
self.ap.logger.warning(
f'Failed to persist steering event {event.event_id} for run {target_run_id}: {exc}',
exc_info=True,
)
self.ap.logger.info(
f'Claimed event {event.event_id} as steering input for run {target_run_id}'
)
return True
def _build_steering_item(
self,
event: AgentEventEnvelope,
run_id: str,
runner_id: str,
) -> dict[str, typing.Any]:
"""Build the run-scoped steering item returned by the Host pull API."""
return {
'claimed_run_id': run_id,
'runner_id': runner_id,
'claimed_at': int(time.time()),
'event': {
'event_id': event.event_id,
'event_type': event.event_type,
'event_time': event.event_time,
'source': event.source,
'source_event_type': event.source_event_type,
'raw_ref': event.raw_ref.model_dump(mode='json') if event.raw_ref else None,
'data': event.data,
},
'conversation': {
'conversation_id': event.conversation_id,
'thread_id': event.thread_id,
'bot_id': event.bot_id,
'workspace_id': event.workspace_id,
},
'actor': event.actor.model_dump(mode='json') if event.actor else None,
'subject': event.subject.model_dump(mode='json') if event.subject else None,
'input': {
'text': event.input.text if event.input else None,
'contents': [
c.model_dump(mode='json') if hasattr(c, 'model_dump') else c
for c in (event.input.contents if event.input else [])
],
'attachments': [
a.model_dump(mode='json') if hasattr(a, 'model_dump') else a
for a in (event.input.attachments if event.input else [])
],
},
}
async def _invoke_runner(
self,
descriptor: AgentRunnerDescriptor,

View File

@@ -393,6 +393,22 @@ class QueryEntryAdapter:
text = str(content)
contents.append({'type': 'text', 'text': text})
if not contents:
message_chain = getattr(query, 'message_chain', None) or []
for component in message_chain:
if isinstance(component, platform_message.Plain):
component_text = getattr(component, 'text', '')
if component_text:
text_parts.append(component_text)
contents.append({'type': 'text', 'text': component_text})
elif isinstance(component, platform_message.Image):
image_base64 = getattr(component, 'base64', None)
image_url = getattr(component, 'url', None)
if image_base64:
contents.append({'type': 'image_base64', 'image_base64': image_base64})
elif image_url:
contents.append({'type': 'image_url', 'image_url': {'url': image_url}})
if text_parts:
text = ''.join(text_parts)

View File

@@ -32,6 +32,9 @@ class RunAuthorizationSnapshot(typing.TypedDict):
authorized_ids: dict[str, set[str]]
SteeringQueueItem = dict[str, typing.Any]
class AgentRunSession(typing.TypedDict):
"""Session for an active agent runner execution.
@@ -51,6 +54,7 @@ class AgentRunSession(typing.TypedDict):
plugin_identity: str # author/name
authorization: RunAuthorizationSnapshot
status: AgentRunSessionStatus
steering_queue: list[SteeringQueueItem]
class AgentRunSessionRegistry:
@@ -128,6 +132,7 @@ class AgentRunSessionRegistry:
'started_at': now,
'last_activity_at': now,
},
'steering_queue': [],
}
async with self._lock:
@@ -175,6 +180,76 @@ class AgentRunSessionRegistry:
if run_id in self._sessions:
self._sessions[run_id]['status']['last_activity_at'] = int(time.time())
async def find_steering_target(
self,
*,
conversation_id: str,
runner_id: str,
) -> str | None:
"""Find the oldest active run that can accept steering for a conversation."""
async with self._lock:
candidates: list[tuple[int, str]] = []
for run_id, session in self._sessions.items():
authorization = session['authorization']
if session.get('runner_id') != runner_id:
continue
if authorization.get('conversation_id') != conversation_id:
continue
if not authorization.get('available_apis', {}).get('steering_pull', False):
continue
candidates.append((session['status'].get('started_at', 0), run_id))
if not candidates:
return None
candidates.sort(key=lambda item: item[0])
return candidates[0][1]
async def enqueue_steering(
self,
run_id: str,
item: SteeringQueueItem,
) -> bool:
"""Append one steering item to an active run queue."""
async with self._lock:
session = self._sessions.get(run_id)
if session is None:
return False
session['steering_queue'].append(copy.deepcopy(item))
session['status']['last_activity_at'] = int(time.time())
return True
async def pull_steering(
self,
run_id: str,
*,
mode: str = 'all',
limit: int | None = None,
) -> list[SteeringQueueItem]:
"""Pop pending steering items from a run queue."""
async with self._lock:
session = self._sessions.get(run_id)
if session is None:
return []
queue = session['steering_queue']
if not queue:
return []
normalized_mode = str(mode or 'all').lower()
if normalized_mode in {'one', 'one-at-a-time', 'one_at_a_time'}:
count = 1
elif isinstance(limit, int) and limit > 0:
count = min(limit, len(queue))
else:
count = len(queue)
count = max(0, min(count, len(queue), 100))
items = [copy.deepcopy(item) for item in queue[:count]]
del queue[:count]
session['status']['last_activity_at'] = int(time.time())
return items
def is_resource_allowed(
self,
session: AgentRunSession,

View File

@@ -21,11 +21,38 @@ class Controller:
self.ap = ap
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
async def _try_claim_steering_before_session_slot(
self,
query: pipeline_query.Query,
) -> bool:
"""Claim steering while the normal per-session slot is still busy.
Follow-up input must be claimed before it waits behind the session
semaphore; otherwise the active run can finish before the query reaches
ChatMessageHandler.try_claim_steering_from_query.
"""
pipeline_uuid = query.pipeline_uuid
if not pipeline_uuid:
return False
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
if not pipeline:
return False
session = await self.ap.sess_mgr.get_session(query)
query.session = session
query.pipeline_config = pipeline.pipeline_entity.config
query.variables['_pipeline_bound_plugins'] = pipeline.bound_plugins
query.variables['_pipeline_bound_mcp_servers'] = pipeline.bound_mcp_servers
return await self.ap.agent_run_orchestrator.try_claim_steering_from_query(query)
async def consumer(self):
"""事件处理循环"""
try:
while True:
selected_query: pipeline_query.Query = None
claimed_steering_query: pipeline_query.Query = None
# 取请求
async with self.ap.query_pool:
@@ -36,6 +63,13 @@ class Controller:
# Debug logging removed from tight loop to prevent excessive log generation
# that can cause memory overflow in high-traffic scenarios
if session._semaphore.locked():
if await self._try_claim_steering_before_session_slot(query):
claimed_steering_query = query
self.ap.logger.debug(f'Claimed query {query.query_id} as steering before session slot')
break
continue
if not session._semaphore.locked():
selected_query = query
await session._semaphore.acquire()
@@ -44,7 +78,12 @@ class Controller:
break
if selected_query: # 找到了
if claimed_steering_query:
queries.remove(claimed_steering_query)
self.ap.query_pool.cached_queries.pop(claimed_steering_query.query_id, None)
self.ap.query_pool.condition.notify_all()
continue
elif selected_query: # 找到了
queries.remove(selected_query)
else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限
await self.ap.query_pool.condition.wait()

View File

@@ -83,15 +83,19 @@ class ChatMessageHandler(handler.MessageHandler):
query.user_message.content = [event_ctx.event.user_message_alter]
text_length = 0
try:
is_stream = await query.adapter.is_stream_output_supported()
except AttributeError:
is_stream = False
try:
# Mark start time for telemetry
start_ts = time.time()
if await self.ap.agent_run_orchestrator.try_claim_steering_from_query(query):
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
return
try:
is_stream = await query.adapter.is_stream_output_supported()
except AttributeError:
is_stream = False
# Create a single resp_message_id for the entire streaming response
resp_message_id = uuid.uuid4()
chunk_count = 0

View File

@@ -1719,6 +1719,44 @@ class RuntimeConnectionHandler(handler.Handler):
self.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True)
return handler.ActionResponse.error(message=f'Event page error: {e}')
@self.action(PluginToRuntimeAction.STEERING_PULL)
async def steering_pull(data: dict[str, Any]) -> handler.ActionResponse:
"""Pull pending steering/follow-up inputs for the current run."""
run_id = data.get('run_id')
mode = data.get('mode', 'all')
limit = data.get('limit')
caller_plugin_identity = data.get('caller_plugin_identity')
if not run_id:
return handler.ActionResponse.error(message='run_id is required')
if limit is not None:
try:
limit = int(limit)
except (TypeError, ValueError):
return handler.ActionResponse.error(message='limit must be an integer')
if limit <= 0:
return handler.ActionResponse.error(message='limit must be > 0')
limit = min(limit, 100)
session, error = await _validate_agent_run_session(
run_id,
caller_plugin_identity,
self.ap,
'Steering pull',
api_capability='steering_pull',
)
if error:
return error
session_registry = get_session_registry()
items = await session_registry.pull_steering(
run_id,
mode=str(mode or 'all'),
limit=limit,
)
return handler.ActionResponse.success(data={'items': items})
# ================= Artifact APIs =================
@self.action(PluginToRuntimeAction.ARTIFACT_METADATA)
@@ -1881,6 +1919,7 @@ class RuntimeConnectionHandler(handler.Handler):
caller_plugin_identity,
self.ap,
'State get',
api_capability='state',
)
if error:
return error
@@ -1927,6 +1966,7 @@ class RuntimeConnectionHandler(handler.Handler):
caller_plugin_identity,
self.ap,
'State set',
api_capability='state',
)
if error:
return error
@@ -1988,6 +2028,7 @@ class RuntimeConnectionHandler(handler.Handler):
caller_plugin_identity,
self.ap,
'State delete',
api_capability='state',
)
if error:
return error
@@ -2035,6 +2076,7 @@ class RuntimeConnectionHandler(handler.Handler):
caller_plugin_identity,
self.ap,
'State list',
api_capability='state',
)
if error:
return error