mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 07:54:19 +00:00
48905ea080
* feat(plugin): report deferred response delivery failures * style: fix ruff format issues in plugin_diagnostics and test_handler_actions --------- Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
275 lines
8.9 KiB
Python
275 lines
8.9 KiB
Python
from __future__ import annotations
|
|
|
|
import traceback
|
|
import weakref
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
|
|
|
|
|
@dataclass(frozen=True)
class PluginResponseSource:
    """Immutable record of the plugin a delivered response is attributed to.

    ``plugin`` is an ``{'author': ..., 'name': ...}`` reference.
    ``event_name`` is the event that produced the response, when known.
    ``is_approximate`` marks attributions guessed from every plugin that
    handled the event, as opposed to ones reported exactly by the plugin
    runtime.
    """

    # {'author': str, 'name': str} identifying the plugin.
    plugin: dict[str, str]
    # Event that produced the response; None when the caller supplied none.
    event_name: str | None = None
    # True when the source was inferred from the emitted-plugins list.
    is_approximate: bool = False
|
|
|
|
|
|
@dataclass
class QueryDiagnosticState:
    """Per-query bookkeeping that links response messages to plugin sources.

    Sources are first parked under the ``id()`` of a message chain whose
    response position is not yet known. They are then moved into
    ``by_response_index`` once that position has been assigned.
    """

    # id(message_chain) -> sources recorded before the chain had a response index.
    pending_by_chain_id: dict[int, list[PluginResponseSource]] = field(default_factory=dict)
    # response index -> sources attributed to the response at that index.
    by_response_index: dict[int, list[PluginResponseSource]] = field(default_factory=dict)
    # weakref.finalize hook that drops this state when the query is collected;
    # None when the query could not be weakly referenced.
    finalizer: weakref.finalize | None = None
|
|
|
|
|
|
# Module-level registry of diagnostic state, keyed by id(query). An entry is
# removed by its weakref finalizer when the query is collected, or explicitly
# once the state has no pending or indexed sources left.
# NOTE(review): when a query cannot be weakly referenced, no finalizer is
# registered. Its entry then lives until it becomes empty, and a later query
# that reuses the same id() could observe that stale state.
_QUERY_STATES: dict[int, QueryDiagnosticState] = {}
|
|
|
|
|
|
def record_plugin_response_source(
    query: pipeline_query.Query,
    response_index: int,
    response_sources: list[dict[str, Any]] | None,
    emitted_plugins: list[dict[str, Any]] | None = None,
    event_name: str | None = None,
) -> None:
    """Attribute the response at ``response_index`` of ``query`` to plugins.

    The attribution is built from ``response_sources`` when it is not
    ``None``. Otherwise it comes from ``emitted_plugins`` and is marked
    approximate. When no plugin reference can be extracted, nothing is
    recorded and no per-query state is created.
    """
    sources = _build_plugin_sources(response_sources, emitted_plugins, event_name)
    if sources:
        _get_or_create_query_state(query).by_response_index[response_index] = sources
|
|
|
|
|
|
def record_last_plugin_response_source(
    query: pipeline_query.Query,
    response_sources: list[dict[str, Any]] | None,
    emitted_plugins: list[dict[str, Any]] | None = None,
    event_name: str | None = None,
) -> None:
    """Attribute the most recently appended response of ``query`` to plugins.

    This is a convenience wrapper around ``record_plugin_response_source``
    that targets the last entry of ``query.resp_message_chain``.

    Nothing is recorded when the query has no responses (an empty or ``None``
    chain). In that case there is no valid index to attach the attribution
    to: ``len(...) - 1`` would yield ``-1``, or raise for ``None``.
    """
    responses = query.resp_message_chain
    if not responses:
        return
    record_plugin_response_source(
        query,
        len(responses) - 1,
        response_sources,
        emitted_plugins,
        event_name,
    )
|
|
|
|
|
|
def record_pending_plugin_response_source(
    query: pipeline_query.Query,
    message_chain: platform_message.MessageChain,
    response_sources: list[dict[str, Any]] | None,
    emitted_plugins: list[dict[str, Any]] | None = None,
    event_name: str | None = None,
) -> None:
    """Remember plugin sources for a message chain whose response index is unknown.

    The sources are keyed by ``id(message_chain)``.
    ``consume_pending_plugin_response_source`` later moves them to a response
    index. The call is a no-op when no plugin reference can be extracted.
    """
    sources = _build_plugin_sources(response_sources, emitted_plugins, event_name)
    if sources:
        pending = _get_or_create_query_state(query).pending_by_chain_id
        pending[id(message_chain)] = sources
|
|
|
|
|
|
def consume_pending_plugin_response_source(
    query: pipeline_query.Query,
    message_chain: platform_message.MessageChain,
    response_index: int,
) -> None:
    """Move the pending sources recorded for ``message_chain`` to ``response_index``.

    This does nothing when the query has no diagnostic state or nothing is
    pending for this chain. A found entry is popped from the pending map.
    """
    state = _get_query_state(query)
    if state is None:
        return
    moved = state.pending_by_chain_id.pop(id(message_chain), None)
    if moved is not None:
        state.by_response_index[response_index] = moved
|
|
|
|
|
|
def clear_response_source(query: pipeline_query.Query, response_index: int) -> None:
    """Forget the plugin attribution stored for ``response_index`` of ``query``.

    If this leaves the query with no pending or indexed sources, the whole
    per-query state is dropped as well.
    """
    state = _get_query_state(query)
    if state is not None:
        state.by_response_index.pop(response_index, None)
        _discard_query_state_if_empty(query)
|
|
|
|
|
|
async def notify_response_delivery_failure(
    ap: Any,
    query: pipeline_query.Query,
    response_index: int,
    message_chain: platform_message.MessageChain,
    error: Exception,
) -> None:
    """Report a failed response delivery to each plugin the response is attributed to.

    One diagnostic payload per recorded source is sent through
    ``ap.plugin_connector.notify_plugin_diagnostic``. The call is
    best-effort:

    - It returns quietly when no sources were recorded for ``response_index``.
    - It returns quietly when ``ap`` has no usable connector.
    - Exceptions (``Exception`` subclasses) are only passed to ``_debug`` and
      never propagate to the caller.
    """
    try:
        sources = _get_response_sources(query, response_index)
        if not sources:
            return

        connector = getattr(ap, 'plugin_connector', None)
        can_forward = connector is not None and hasattr(connector, 'notify_plugin_diagnostic')
        if not can_forward:
            return

        for src in sources:
            diagnostic = _build_delivery_failure_payload(
                plugin_ref=src.plugin,
                event_name=src.event_name,
                is_approximate=src.is_approximate,
                query=query,
                response_index=response_index,
                message_chain=message_chain,
                error=error,
            )
            try:
                await connector.notify_plugin_diagnostic(diagnostic)
            except Exception as forward_error:
                # A failure for one plugin must not stop reporting to the others.
                _debug(ap, f'Plugin diagnostic forwarding failed: {forward_error}')
    except Exception as unexpected_error:
        _debug(ap, f'Plugin diagnostic forwarding skipped: {unexpected_error}')
|
|
|
|
|
|
def get_emitted_plugins(event_ctx: Any) -> list[dict[str, Any]]:
    """Return ``event_ctx._emitted_plugins`` when it is a list, otherwise ``[]``.

    A missing attribute (``AttributeError``) also yields ``[]``.
    """
    try:
        plugins = event_ctx._emitted_plugins
    except AttributeError:
        return []
    if isinstance(plugins, list):
        return plugins
    return []
|
|
|
|
|
|
def get_response_sources(event_ctx: Any) -> list[dict[str, Any]] | None:
    """Read ``_response_sources`` from the instance dictionary of ``event_ctx``.

    Only ``vars(event_ctx)`` is consulted, so class attributes are ignored.

    Returns:
        ``None`` if the attribute is absent from the instance dictionary.
        The stored list itself when the value is a list.
        ``[]`` for any other value.

    Raises:
        TypeError: propagated from ``vars()`` when ``event_ctx`` has no ``__dict__``.
    """
    instance_attrs = vars(event_ctx)
    try:
        stored = instance_attrs['_response_sources']
    except KeyError:
        return None
    if isinstance(stored, list):
        return stored
    return []
|
|
|
|
|
|
def _get_or_create_query_state(query: pipeline_query.Query) -> QueryDiagnosticState:
    """Return the diagnostic state for ``query``, registering a fresh one if needed.

    A new state gets a ``weakref.finalize`` hook that removes it from
    ``_QUERY_STATES`` once ``query`` is garbage-collected. Objects that cannot
    be weakly referenced (``weakref.finalize`` raises ``TypeError``) are still
    registered, just without the hook.
    """
    key = id(query)
    state = _QUERY_STATES.get(key)
    if state is None:
        state = QueryDiagnosticState()
        try:
            state.finalizer = weakref.finalize(query, _discard_query_state, key)
        except TypeError:
            # No automatic cleanup is possible. Within this module the entry is
            # then only removed by _discard_query_state_if_empty.
            state.finalizer = None
        _QUERY_STATES[key] = state
    return state
|
|
|
|
|
|
def _get_query_state(query: pipeline_query.Query) -> QueryDiagnosticState | None:
    """Look up the diagnostic state registered for ``query`` without creating one."""
    key = id(query)
    return _QUERY_STATES.get(key)
|
|
|
|
|
|
def _discard_query_state(query_key: int) -> None:
    """Remove the registry entry stored under ``query_key``.

    Missing keys are ignored. This function also serves as the
    ``weakref.finalize`` callback.
    """
    # A single pop() with a default: no separate membership test is needed.
    _QUERY_STATES.pop(query_key, None)
|
|
|
|
|
|
def _discard_query_state_if_empty(query: pipeline_query.Query) -> None:
    """Drop ``query``'s state once it holds neither pending nor indexed sources.

    Any registered weakref finalizer is detached first, so it no longer tracks
    the query after the state has been removed explicitly.
    """
    key = id(query)
    state = _QUERY_STATES.get(key)
    if state is None or state.pending_by_chain_id or state.by_response_index:
        return
    finalizer = state.finalizer
    if finalizer is not None:
        finalizer.detach()
    _discard_query_state(key)
|
|
|
|
|
|
def _get_response_sources(
    query: pipeline_query.Query,
    response_index: int,
) -> list[PluginResponseSource]:
    """Return the sources recorded for ``response_index`` of ``query``.

    Returns ``[]`` when the query has no state or no entry exists for that
    index. The stored list itself is returned, not a copy.
    """
    state = _get_query_state(query)
    return [] if state is None else state.by_response_index.get(response_index, [])
|
|
|
|
|
|
def _extract_plugin_ref(plugin: Any) -> dict[str, str] | None:
    """Build an ``{'author', 'name'}`` reference from a plugin's manifest metadata.

    Reads ``plugin['manifest']['metadata']``, which must contain truthy
    ``author`` and ``name`` values (both converted with ``str``). Returns
    ``None`` in either of these cases:

    - any level is not a dict;
    - either value is missing or falsy.
    """
    if not isinstance(plugin, dict):
        return None
    manifest = plugin.get('manifest')
    if not isinstance(manifest, dict):
        return None
    metadata = manifest.get('metadata')
    if not isinstance(metadata, dict):
        return None
    author = metadata.get('author')
    name = metadata.get('name')
    if author and name:
        return {'author': str(author), 'name': str(name)}
    return None
|
|
|
|
|
|
def _extract_response_source_plugin_ref(source: Any) -> dict[str, str] | None:
    """Pull the plugin reference out of a ``reply_message_chain`` response source.

    Only dict sources qualify, and only when both hold:

    - ``kind`` is ``'reply_message_chain'``;
    - the ``plugin`` dict has truthy ``author`` and ``name`` values.

    Those two values are converted with ``str``. Anything else yields ``None``.
    """
    if not isinstance(source, dict):
        return None
    if source.get('kind') != 'reply_message_chain':
        return None
    ref = source.get('plugin')
    if not isinstance(ref, dict):
        return None
    author = ref.get('author')
    name = ref.get('name')
    return {'author': str(author), 'name': str(name)} if author and name else None
|
|
|
|
|
|
def _build_plugin_sources(
    response_sources: list[dict[str, Any]] | None,
    emitted_plugins: list[dict[str, Any]] | None,
    event_name: str | None,
) -> list[PluginResponseSource]:
    """Turn raw plugin-runtime data into ``PluginResponseSource`` records.

    Selection rules:

    - ``response_sources`` takes precedence whenever it is not ``None``, even
      if empty. Only its ``reply_message_chain`` entries are used, and the
      result is exact.
    - Otherwise every plugin in ``emitted_plugins`` is used, and the result is
      flagged ``is_approximate``.

    Entries without a usable author/name are skipped. A plugin listed more
    than once (e.g. one that replied several times) is kept only at its first
    position, so it is not sent duplicate diagnostics for the same response.
    """
    if response_sources is not None:
        refs = [_extract_response_source_plugin_ref(source) for source in response_sources]
        approximate = False
    elif emitted_plugins:
        refs = [_extract_plugin_ref(plugin) for plugin in emitted_plugins]
        approximate = True
    else:
        return []

    sources: list[PluginResponseSource] = []
    seen: set[tuple[str, str]] = set()
    for ref in refs:
        if ref is None:
            continue
        key = (ref['author'], ref['name'])
        if key in seen:
            continue
        seen.add(key)
        sources.append(PluginResponseSource(plugin=ref, event_name=event_name, is_approximate=approximate))
    return sources
|
|
|
|
|
|
def _debug(ap: Any, message: str) -> None:
    """Log ``message`` at debug level through ``ap.logger``.

    This is a no-op when ``ap`` has no ``logger`` attribute or it is ``None``.
    """
    logger = getattr(ap, 'logger', None)
    if logger is None:
        return
    logger.debug(message)
|
|
|
|
|
|
def _build_delivery_failure_payload(
    plugin_ref: dict[str, str],
    event_name: str | None,
    is_approximate: bool,
    query: pipeline_query.Query,
    response_index: int,
    message_chain: platform_message.MessageChain,
    error: Exception,
) -> dict[str, Any]:
    """Assemble the ``response_delivery_failed`` diagnostic sent to one plugin.

    The payload contains:

    - the plugin reference;
    - query context (id, event name, stage, response index);
    - a short description of the undelivered message (component class names
      and the first 200 characters of its string form);
    - the delivery error.

    Approximate attributions carry an extra ``attribution_warning``.

    When ``event_name`` is falsy, the class name of ``query.message_event`` is
    used. When ``query.current_stage_name`` is falsy, ``'SendResponseBackStage'``
    is used.
    """
    details: dict[str, Any] = {
        'message_component_types': [component.__class__.__name__ for component in message_chain],
        'message_preview': str(message_chain)[:200],
    }
    if is_approximate:
        details['attribution_warning'] = (
            'This diagnostic was delivered to all plugins that handled the event because the '
            'plugin runtime did not report the exact reply_message_chain source.'
        )

    return {
        'level': 'ERROR',
        'code': 'response_delivery_failed',
        'message': 'Failed to deliver a plugin-provided response message.',
        'plugin': plugin_ref,
        'query': {
            'query_id': query.query_id,
            'event_name': event_name or query.message_event.__class__.__name__,
            'stage': query.current_stage_name or 'SendResponseBackStage',
            'response_index': response_index,
        },
        'details': details,
        'delivery': {
            'error_type': error.__class__.__name__,
            'error_message': str(error),
            # format_exception_only() can return several lines: SyntaxError
            # location context, or exception notes on Python 3.11+. In those
            # cases the last line is not the "Type: message" summary, so keep
            # them all.
            'traceback': ''.join(traceback.format_exception_only(type(error), error)).strip(),
        },
    }
|