mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-12 16:56:02 +00:00
feat(agent-runner): support run steering
This commit is contained in:
@@ -103,6 +103,7 @@ class AgentRunnerCapabilities(BaseModel):
|
||||
multimodal_input: bool = False
|
||||
skill_authoring: bool = False
|
||||
interrupt: bool = False
|
||||
steering: bool = False
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
```
|
||||
@@ -113,6 +114,7 @@ class AgentRunnerCapabilities(BaseModel):
|
||||
- `multimodal_input`: runner 可以处理非纯文本 input / artifact。
|
||||
- `skill_authoring`: runner 需要 Host 提供 skill facts 以及 skill authoring tools,例如 `activate` / `register_skill`。
|
||||
- `interrupt`: runner 支持取消或中断。
|
||||
- `steering`: runner 支持在 turn 边界通过 Host pull API 消费同 conversation 在途追加消息。
|
||||
|
||||
Capabilities 字段全部是 `bool`,未知 key 禁止进入 typed manifest。早期草案里的上下文/会话类 capability 已删除;对应语义由 event-first context 和 runner-owned context 原则表达。
|
||||
|
||||
@@ -323,6 +325,7 @@ class ContextAPICapabilities(BaseModel):
|
||||
artifact_read: bool = False
|
||||
state: bool = False
|
||||
storage: bool = False
|
||||
steering_pull: bool = False
|
||||
```
|
||||
|
||||
`ContextAccess` 告诉 runner:Host inline 了什么、没 inline 什么、需要更多上下文时走哪些 API。它是 runner 按需读取上下文的入口说明,不是 Host 的业务上下文编排策略。
|
||||
@@ -483,6 +486,7 @@ await api.history_search(query, filters=None, top_k=10)
|
||||
# Event(返回稳定 event envelope 或受限 raw ref,不默认返回大 payload)
|
||||
await api.event_get(event_id)
|
||||
await api.event_page(before_cursor=None, limit=50)
|
||||
await api.steering_pull(mode="all", limit=None)
|
||||
|
||||
# Artifact(必须支持大小限制、MIME 校验、过期时间和授权范围)
|
||||
await api.artifact_metadata(artifact_id)
|
||||
@@ -563,6 +567,20 @@ class EventPage(BaseModel):
|
||||
has_more: bool = False
|
||||
total_count: int | None = None
|
||||
|
||||
class SteeringInputItem(BaseModel):
    """One follow-up input that the Host claimed for an active run.

    Instances are what the steering pull API hands back to a runner: the
    claimed event together with its input and the context it arrived in.
    """

    # Claim bookkeeping: the run and runner the input was attached to.
    claimed_run_id: str
    runner_id: str
    # Integer timestamp of the claim; may be absent.
    claimed_at: int | None = None

    # The claimed event and the input it carried.
    event: AgentEventContext
    input: AgentInput

    # Optional context describing where the event came from.
    conversation: ConversationContext | None = None
    actor: ActorContext | None = None
    subject: SubjectContext | None = None

    # Free-form extra data; empty unless a producer fills it in.
    metadata: dict[str, Any] = {}
|
||||
|
||||
class SteeringPullResult(BaseModel):
    """Result envelope of a steering pull.

    ``items`` holds the steering inputs handed to the runner by this pull
    and is empty when nothing was pending.
    """

    items: list[SteeringInputItem] = []
|
||||
|
||||
class ArtifactMetadata(BaseModel):
|
||||
artifact_id: str
|
||||
artifact_type: str
|
||||
|
||||
@@ -371,6 +371,7 @@ class AgentRunContextBuilder:
|
||||
event_page_enabled = 'page' in event_perms and conversation_id is not None
|
||||
artifact_metadata_enabled = 'metadata' in artifact_perms
|
||||
artifact_read_enabled = 'read' in artifact_perms
|
||||
steering_pull_enabled = bool(getattr(descriptor.capabilities, 'steering', False)) and conversation_id is not None
|
||||
|
||||
# Determine state API availability based on binding state_policy.
|
||||
state_enabled = False
|
||||
@@ -425,5 +426,6 @@ class AgentRunContextBuilder:
|
||||
'artifact_read': artifact_read_enabled,
|
||||
'state': state_enabled,
|
||||
'storage': storage_enabled,
|
||||
'steering_pull': steering_pull_enabled,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -76,3 +76,7 @@ class AgentRunnerDescriptor(pydantic.BaseModel):
|
||||
def supports_knowledge_retrieval(self) -> bool:
    """Return the runner's declared ``knowledge_retrieval`` capability flag."""
    declared = self.capabilities
    return declared.knowledge_retrieval
|
||||
|
||||
def supports_steering(self) -> bool:
    """Report whether the runner declares the ``steering`` capability.

    The flag is looked up with a ``False`` fallback, so a capabilities
    object that has no ``steering`` attribute counts as "not supported".
    The result is always coerced to a plain ``bool``.
    """
    steering_flag = getattr(self.capabilities, 'steering', False)
    return bool(steering_flag)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import typing
|
||||
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
@@ -201,6 +202,98 @@ class AgentRunOrchestrator:
|
||||
"""Resolve runner ID for telemetry/logging without full execution."""
|
||||
return self.query_bridge.resolve_runner_id_for_telemetry(query)
|
||||
|
||||
async def try_claim_steering_from_query(
    self,
    query: pipeline_query.Query,
) -> bool:
    """Try to attach ``query`` to an already-active run as steering input.

    Args:
        query: The incoming pipeline query to inspect.

    Returns:
        ``True`` when the query's event was queued on an active run.
        ``False`` when it was not claimed: the event is not a
        ``message.received`` with a conversation id, the runner does not
        support steering, no eligible run exists, or the target run could
        not accept the item.
    """
    plan = self.query_bridge.build_plan(query)
    incoming, binding = plan.event, plan.binding

    # Only conversation-bound inbound messages are eligible.
    is_message = incoming.event_type == 'message.received'
    if not (is_message and incoming.conversation_id):
        return False

    runner = await self.registry.get(binding.runner_id, plan.bound_plugins)
    if not runner.supports_steering():
        return False

    sessions = self._session_registry
    run_id = await sessions.find_steering_target(
        conversation_id=incoming.conversation_id,
        runner_id=runner.id,
    )
    if run_id is None:
        return False

    payload = self._build_steering_item(incoming, run_id, runner.id)
    queued = await sessions.enqueue_steering(run_id, payload)
    if not queued:
        return False

    # The item is already queued at this point, so journal persistence is
    # best-effort: a failure is logged and the claim still stands.
    try:
        log_id = await self.journal.write_event_log(
            event=incoming,
            binding=binding,
            run_id=run_id,
            runner_id=runner.id,
        )
        await self.journal.register_input_artifacts(
            event=incoming,
            run_id=run_id,
            runner_id=runner.id,
        )
        await self.journal.write_user_transcript(incoming, log_id)
    except Exception as exc:
        self.ap.logger.warning(
            f'Failed to persist steering event {incoming.event_id} for run {run_id}: {exc}',
            exc_info=True,
        )

    self.ap.logger.info(f'Claimed event {incoming.event_id} as steering input for run {run_id}')
    return True
|
||||
|
||||
def _build_steering_item(
    self,
    event: AgentEventEnvelope,
    run_id: str,
    runner_id: str,
) -> dict[str, typing.Any]:
    """Build the run-scoped steering item returned by the Host pull API.

    Args:
        event: The event being claimed as steering input.
        run_id: Id of the run the event is attached to.
        runner_id: Id of the runner executing that run.

    Returns:
        A plain dict with the keys ``claimed_run_id``, ``runner_id``,
        ``claimed_at``, ``event``, ``conversation``, ``actor``, ``subject``
        and ``input``. Model-like values are dumped with
        ``model_dump(mode='json')``.
    """

    def _dump_or_none(model: typing.Any) -> typing.Any:
        # Optional sub-models become JSON-mode dumps; falsy ones become None.
        return model.model_dump(mode='json') if model else None

    def _dump_if_model(part: typing.Any) -> typing.Any:
        # List entries may already be plain data; only model-like ones are dumped.
        return part.model_dump(mode='json') if hasattr(part, 'model_dump') else part

    user_input = event.input
    if user_input:
        text = user_input.text
        raw_contents = user_input.contents
        raw_attachments = user_input.attachments
    else:
        text, raw_contents, raw_attachments = None, [], []

    event_section = {
        'event_id': event.event_id,
        'event_type': event.event_type,
        'event_time': event.event_time,
        'source': event.source,
        'source_event_type': event.source_event_type,
        'raw_ref': _dump_or_none(event.raw_ref),
        'data': event.data,
    }
    conversation_section = {
        'conversation_id': event.conversation_id,
        'thread_id': event.thread_id,
        'bot_id': event.bot_id,
        'workspace_id': event.workspace_id,
    }
    input_section = {
        'text': text,
        'contents': [_dump_if_model(part) for part in raw_contents],
        'attachments': [_dump_if_model(part) for part in raw_attachments],
    }

    return {
        'claimed_run_id': run_id,
        'runner_id': runner_id,
        'claimed_at': int(time.time()),
        'event': event_section,
        'conversation': conversation_section,
        'actor': _dump_or_none(event.actor),
        'subject': _dump_or_none(event.subject),
        'input': input_section,
    }
|
||||
|
||||
async def _invoke_runner(
|
||||
self,
|
||||
descriptor: AgentRunnerDescriptor,
|
||||
|
||||
@@ -393,6 +393,22 @@ class QueryEntryAdapter:
|
||||
text = str(content)
|
||||
contents.append({'type': 'text', 'text': text})
|
||||
|
||||
if not contents:
|
||||
message_chain = getattr(query, 'message_chain', None) or []
|
||||
for component in message_chain:
|
||||
if isinstance(component, platform_message.Plain):
|
||||
component_text = getattr(component, 'text', '')
|
||||
if component_text:
|
||||
text_parts.append(component_text)
|
||||
contents.append({'type': 'text', 'text': component_text})
|
||||
elif isinstance(component, platform_message.Image):
|
||||
image_base64 = getattr(component, 'base64', None)
|
||||
image_url = getattr(component, 'url', None)
|
||||
if image_base64:
|
||||
contents.append({'type': 'image_base64', 'image_base64': image_base64})
|
||||
elif image_url:
|
||||
contents.append({'type': 'image_url', 'image_url': {'url': image_url}})
|
||||
|
||||
if text_parts:
|
||||
text = ''.join(text_parts)
|
||||
|
||||
|
||||
@@ -32,6 +32,9 @@ class RunAuthorizationSnapshot(typing.TypedDict):
|
||||
authorized_ids: dict[str, set[str]]
|
||||
|
||||
|
||||
SteeringQueueItem = dict[str, typing.Any]
|
||||
|
||||
|
||||
class AgentRunSession(typing.TypedDict):
|
||||
"""Session for an active agent runner execution.
|
||||
|
||||
@@ -51,6 +54,7 @@ class AgentRunSession(typing.TypedDict):
|
||||
plugin_identity: str # author/name
|
||||
authorization: RunAuthorizationSnapshot
|
||||
status: AgentRunSessionStatus
|
||||
steering_queue: list[SteeringQueueItem]
|
||||
|
||||
|
||||
class AgentRunSessionRegistry:
|
||||
@@ -128,6 +132,7 @@ class AgentRunSessionRegistry:
|
||||
'started_at': now,
|
||||
'last_activity_at': now,
|
||||
},
|
||||
'steering_queue': [],
|
||||
}
|
||||
|
||||
async with self._lock:
|
||||
@@ -175,6 +180,76 @@ class AgentRunSessionRegistry:
|
||||
if run_id in self._sessions:
|
||||
self._sessions[run_id]['status']['last_activity_at'] = int(time.time())
|
||||
|
||||
async def find_steering_target(
    self,
    *,
    conversation_id: str,
    runner_id: str,
) -> str | None:
    """Find the oldest active run that can accept steering for a conversation.

    A session is eligible when its ``runner_id`` matches, its authorization
    snapshot is bound to ``conversation_id``, and that snapshot enables the
    ``steering_pull`` API.

    Args:
        conversation_id: Conversation the follow-up input belongs to.
        runner_id: Runner that must own the target run.

    Returns:
        The run id of the eligible session with the smallest
        ``status['started_at']`` (a missing value counts as ``0``), or
        ``None`` when no session is eligible. Among sessions with the same
        start time, the one that appears first in the registry wins.
    """
    async with self._lock:
        eligible: list[tuple[int, str]] = []
        for run_id, session in self._sessions.items():
            authorization = session['authorization']
            if session.get('runner_id') != runner_id:
                continue
            if authorization.get('conversation_id') != conversation_id:
                continue
            if not authorization.get('available_apis', {}).get('steering_pull', False):
                continue
            eligible.append((session['status'].get('started_at', 0), run_id))

    # One pass with min() instead of sorting the whole list. min() returns
    # the first minimal element, so equal start times resolve the same way
    # a stable sort would.
    oldest = min(eligible, key=lambda candidate: candidate[0], default=None)
    return None if oldest is None else oldest[1]
|
||||
|
||||
async def enqueue_steering(
    self,
    run_id: str,
    item: SteeringQueueItem,
) -> bool:
    """Queue one steering item on the session registered under ``run_id``.

    A deep copy of ``item`` is stored, so later changes to the caller's
    object do not reach the queued entry. A successful enqueue also
    refreshes the session's ``last_activity_at``.

    Returns:
        ``True`` when the item was queued, ``False`` when no session is
        registered for ``run_id``.
    """
    async with self._lock:
        target = self._sessions.get(run_id)
        found = target is not None
        if found:
            snapshot = copy.deepcopy(item)
            target['steering_queue'].append(snapshot)
            target['status']['last_activity_at'] = int(time.time())
        return found
|
||||
|
||||
async def pull_steering(
    self,
    run_id: str,
    *,
    mode: str = 'all',
    limit: int | None = None,
) -> list[SteeringQueueItem]:
    """Remove and return pending steering items from the front of a run's queue.

    Args:
        run_id: Run whose queue is drained.
        mode: ``'one'``, ``'one-at-a-time'`` or ``'one_at_a_time'``
            (case-insensitive) returns a single item; any other value
            drains as many items as ``limit`` allows.
        limit: Positive integer upper bound on the number of items; other
            values mean "no explicit limit". Ignored in single-item mode.

    Returns:
        Deep copies of the removed items, oldest first. At most 100 items
        are returned per call. The list is empty when the run is unknown or
        nothing is pending; in that case the session is left untouched.
    """
    single_item_modes = {'one', 'one-at-a-time', 'one_at_a_time'}

    async with self._lock:
        session = self._sessions.get(run_id)
        if session is None:
            return []

        pending = session['steering_queue']
        if not pending:
            return []

        # Decide how many items were asked for, then clamp once to what is
        # actually pending and to the hard per-call cap.
        if str(mode or 'all').lower() in single_item_modes:
            take = 1
        elif isinstance(limit, int) and limit > 0:
            take = limit
        else:
            take = len(pending)
        take = min(take, len(pending), 100)

        pulled = [copy.deepcopy(entry) for entry in pending[:take]]
        del pending[:take]
        session['status']['last_activity_at'] = int(time.time())
        return pulled
|
||||
|
||||
def is_resource_allowed(
|
||||
self,
|
||||
session: AgentRunSession,
|
||||
|
||||
@@ -21,11 +21,38 @@ class Controller:
|
||||
self.ap = ap
|
||||
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
|
||||
|
||||
async def _try_claim_steering_before_session_slot(
    self,
    query: pipeline_query.Query,
) -> bool:
    """Claim steering while the normal per-session slot is still busy.

    Follow-up input must be claimed before it waits behind the session
    semaphore; otherwise the active run can finish before the query reaches
    the steering check that ChatMessageHandler performs through
    ``agent_run_orchestrator.try_claim_steering_from_query``.

    The pre-claim is opportunistic. If looking up the pipeline or session,
    or the claim attempt itself, raises, the failure is logged and ``False``
    is returned so the query simply keeps waiting for its session slot
    instead of the error propagating into the caller's loop.

    Args:
        query: Pending query whose session slot is currently occupied. Its
            ``session``, ``pipeline_config`` and bound plugin / MCP server
            variables are filled in before the claim attempt.

    Returns:
        ``True`` when the orchestrator claimed the query as steering input,
        ``False`` otherwise.
    """
    pipeline_uuid = query.pipeline_uuid
    if not pipeline_uuid:
        return False

    try:
        pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
        if not pipeline:
            return False

        # Fill in the pipeline context on the query before handing it to the
        # orchestrator (presumably read while it builds its plan; verify
        # against query_bridge.build_plan).
        query.session = await self.ap.sess_mgr.get_session(query)
        query.pipeline_config = pipeline.pipeline_entity.config
        query.variables['_pipeline_bound_plugins'] = pipeline.bound_plugins
        query.variables['_pipeline_bound_mcp_servers'] = pipeline.bound_mcp_servers

        return await self.ap.agent_run_orchestrator.try_claim_steering_from_query(query)
    except Exception as exc:
        # Best-effort boundary: never let a failed pre-claim escape; fall
        # back to regular queueing for this query.
        self.ap.logger.warning(
            f'Steering pre-claim failed for query {getattr(query, "query_id", None)}: {exc}',
            exc_info=True,
        )
        return False
|
||||
|
||||
async def consumer(self):
|
||||
"""事件处理循环"""
|
||||
try:
|
||||
while True:
|
||||
selected_query: pipeline_query.Query = None
|
||||
claimed_steering_query: pipeline_query.Query = None
|
||||
|
||||
# 取请求
|
||||
async with self.ap.query_pool:
|
||||
@@ -36,6 +63,13 @@ class Controller:
|
||||
# Debug logging removed from tight loop to prevent excessive log generation
|
||||
# that can cause memory overflow in high-traffic scenarios
|
||||
|
||||
if session._semaphore.locked():
|
||||
if await self._try_claim_steering_before_session_slot(query):
|
||||
claimed_steering_query = query
|
||||
self.ap.logger.debug(f'Claimed query {query.query_id} as steering before session slot')
|
||||
break
|
||||
continue
|
||||
|
||||
if not session._semaphore.locked():
|
||||
selected_query = query
|
||||
await session._semaphore.acquire()
|
||||
@@ -44,7 +78,12 @@ class Controller:
|
||||
|
||||
break
|
||||
|
||||
if selected_query: # 找到了
|
||||
if claimed_steering_query:
|
||||
queries.remove(claimed_steering_query)
|
||||
self.ap.query_pool.cached_queries.pop(claimed_steering_query.query_id, None)
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
continue
|
||||
elif selected_query: # 找到了
|
||||
queries.remove(selected_query)
|
||||
else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限
|
||||
await self.ap.query_pool.condition.wait()
|
||||
|
||||
@@ -84,15 +84,19 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
query.user_message.content = [event_ctx.event.user_message_alter]
|
||||
|
||||
text_length = 0
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
try:
|
||||
# Mark start time for telemetry
|
||||
start_ts = time.time()
|
||||
|
||||
if await self.ap.agent_run_orchestrator.try_claim_steering_from_query(query):
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
return
|
||||
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
# Create a single resp_message_id for the entire streaming response
|
||||
resp_message_id = uuid.uuid4()
|
||||
chunk_count = 0
|
||||
|
||||
@@ -1719,6 +1719,44 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
self.ap.logger.error(f'EVENT_PAGE error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'Event page error: {e}')
|
||||
|
||||
@self.action(PluginToRuntimeAction.STEERING_PULL)
async def steering_pull(data: dict[str, Any]) -> handler.ActionResponse:
    """Pull pending steering/follow-up inputs for the current run.

    Reads ``run_id``, ``mode``, ``limit`` and ``caller_plugin_identity``
    from ``data``. ``run_id`` is mandatory. ``limit``, when given, must
    convert to a positive integer and is capped at 100. The caller's run
    session is then validated for the ``steering_pull`` capability before
    any item is removed from the queue.

    Returns:
        A success response whose data is ``{'items': [...]}``, or an error
        response describing the first failed check.
    """
    run_id = data.get('run_id')
    if not run_id:
        return handler.ActionResponse.error(message='run_id is required')

    # Normalise the optional limit before touching any session state.
    requested_limit = data.get('limit')
    if requested_limit is not None:
        try:
            requested_limit = int(requested_limit)
        except (TypeError, ValueError):
            return handler.ActionResponse.error(message='limit must be an integer')
        if requested_limit <= 0:
            return handler.ActionResponse.error(message='limit must be > 0')
        requested_limit = min(requested_limit, 100)

    _session, error = await _validate_agent_run_session(
        run_id,
        data.get('caller_plugin_identity'),
        self.ap,
        'Steering pull',
        api_capability='steering_pull',
    )
    if error:
        return error

    pull_mode = str(data.get('mode', 'all') or 'all')
    items = await get_session_registry().pull_steering(
        run_id,
        mode=pull_mode,
        limit=requested_limit,
    )
    return handler.ActionResponse.success(data={'items': items})
|
||||
|
||||
# ================= Artifact APIs =================
|
||||
|
||||
@self.action(PluginToRuntimeAction.ARTIFACT_METADATA)
|
||||
@@ -1881,6 +1919,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'State get',
|
||||
api_capability='state',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -1927,6 +1966,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'State set',
|
||||
api_capability='state',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -1988,6 +2028,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'State delete',
|
||||
api_capability='state',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -2035,6 +2076,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'State list',
|
||||
api_capability='state',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
Reference in New Issue
Block a user