# Mirror of https://github.com/langbot-app/LangBot.git
# Synced 2026-06-08 14:56:03 +00:00
# 297 lines, 10 KiB, Python
"""Message Aggregator Module

This module provides message aggregation/debounce functionality.
When users send multiple messages consecutively, the aggregator waits
for a configurable delay period and merges them into a single message
before processing.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import time
|
|
import typing
|
|
from dataclasses import dataclass, field
|
|
|
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
|
|
|
if typing.TYPE_CHECKING:
|
|
from ..core import app
|
|
|
|
# Upper bound on the number of messages a single session may buffer; once a
# buffer reaches this size it is flushed immediately instead of waiting for
# the debounce timer.
MAX_BUFFER_MESSAGES = 10
|
|
|
|
|
|
@dataclass
class PendingMessage:
    """One incoming message held back while waiting to be aggregated.

    The fields mirror the keyword arguments that are eventually forwarded to
    the query pool, plus the time at which the message was received.
    """

    # Identifier of the bot that received the message.
    bot_uuid: str
    # Launcher (conversation) type and identifier; together with ``bot_uuid``
    # they form the session key used for buffering.
    launcher_type: provider_session.LauncherTypes
    launcher_id: typing.Union[int, str]
    # Identifier of the user who sent the message.
    sender_id: typing.Union[int, str]
    # Original platform event and the message content carried by it.
    message_event: platform_events.MessageEvent
    message_chain: platform_message.MessageChain
    # Platform adapter through which the message arrived.
    adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter
    # Pipeline the message is bound to, or None when no pipeline is set.
    pipeline_uuid: typing.Optional[str]
    # Forwarded unchanged to the query pool.
    routed_by_rule: bool = False
    # Creation time of this record (seconds since the epoch).
    timestamp: float = field(default_factory=time.time)
|
|
|
|
|
|
@dataclass
class SessionBuffer:
    """Pending messages of one session together with its debounce timer."""

    # Key under which this buffer is stored by the aggregator.
    session_id: str
    # Messages collected so far, in arrival order.
    messages: list[PendingMessage] = field(default_factory=list)
    # Task that will flush this buffer after the delay; None if no timer is set.
    timer_task: typing.Optional[asyncio.Task] = None
    # Time (seconds since the epoch) at which the latest message was buffered.
    last_message_time: float = field(default_factory=time.time)
|
|
|
|
|
|
class MessageAggregator:
    """Message aggregator that buffers and merges consecutive messages.

    This class implements a debounce mechanism for incoming messages.
    When a message arrives, it starts a timer. If more messages arrive
    before the timer expires, they are buffered and the timer is restarted.
    When the timer expires (or the buffer reaches ``MAX_BUFFER_MESSAGES``),
    all buffered messages are merged and sent to the query pool.
    """

    # Delay (seconds) used when a pipeline does not configure one, and the
    # range into which any configured delay is clamped.
    DEFAULT_DELAY_SECONDS = 1.5
    MIN_DELAY_SECONDS = 1.0
    MAX_DELAY_SECONDS = 10.0

    # Owning application; this class uses its ``pipeline_mgr`` and ``query_pool``.
    ap: app.Application

    # Session ID -> SessionBuffer mapping
    buffers: dict[str, SessionBuffer]

    # Serializes buffer mutations between coroutines of one event loop.
    # NOTE: asyncio.Lock is not thread-safe; it must not be shared across threads.
    lock: asyncio.Lock

    def __init__(self, ap: app.Application):
        """Create an aggregator with no pending buffers.

        Args:
            ap: Application object providing ``pipeline_mgr`` and ``query_pool``.
        """
        self.ap = ap
        self.buffers = {}
        self.lock = asyncio.Lock()

    def _get_session_id(
        self,
        bot_uuid: str,
        launcher_type: provider_session.LauncherTypes,
        launcher_id: typing.Union[int, str],
    ) -> str:
        """Build the buffer key ``"<bot_uuid>:<launcher_type.value>:<launcher_id>"``."""
        return f'{bot_uuid}:{launcher_type.value}:{launcher_id}'

    async def _get_aggregation_config(self, pipeline_uuid: typing.Optional[str]) -> tuple[bool, float]:
        """Get aggregation configuration for a pipeline.

        The settings are read from ``config['trigger']['message-aggregation']``
        of the pipeline entity. Missing or ``None`` sections fall back to the
        defaults (disabled, ``DEFAULT_DELAY_SECONDS``).

        Args:
            pipeline_uuid: UUID of the pipeline, or None.

        Returns:
            tuple: (enabled, delay_seconds), with the delay clamped to
            ``[MIN_DELAY_SECONDS, MAX_DELAY_SECONDS]``.
        """
        import math  # local import: only needed for the NaN check below

        default_enabled = False
        default_delay = self.DEFAULT_DELAY_SECONDS

        if pipeline_uuid is None:
            return default_enabled, default_delay

        # Get pipeline from pipeline manager
        pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
        if pipeline is None:
            return default_enabled, default_delay

        # ``or {}`` (rather than a .get default) also covers keys that are
        # present but explicitly set to None.
        config = pipeline.pipeline_entity.config or {}
        trigger_config = config.get('trigger') or {}
        aggregation_config = trigger_config.get('message-aggregation') or {}

        enabled = bool(aggregation_config.get('enabled', default_enabled))

        try:
            delay = float(aggregation_config.get('delay', default_delay))
        except (TypeError, ValueError):
            delay = default_delay

        # NaN compares false with everything, so min()/max() below would not
        # clamp it in a meaningful way; treat it as an invalid value instead.
        if math.isnan(delay):
            delay = default_delay

        # Clamp delay to valid range
        delay = max(self.MIN_DELAY_SECONDS, min(self.MAX_DELAY_SECONDS, delay))

        return enabled, delay

    async def _dispatch(self, msg: PendingMessage) -> None:
        """Forward a single (possibly merged) message to the query pool."""
        await self.ap.query_pool.add_query(
            bot_uuid=msg.bot_uuid,
            launcher_type=msg.launcher_type,
            launcher_id=msg.launcher_id,
            sender_id=msg.sender_id,
            message_event=msg.message_event,
            message_chain=msg.message_chain,
            adapter=msg.adapter,
            pipeline_uuid=msg.pipeline_uuid,
            routed_by_rule=msg.routed_by_rule,
        )

    async def add_message(
        self,
        bot_uuid: str,
        launcher_type: provider_session.LauncherTypes,
        launcher_id: typing.Union[int, str],
        sender_id: typing.Union[int, str],
        message_event: platform_events.MessageEvent,
        message_chain: platform_message.MessageChain,
        adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
        pipeline_uuid: typing.Optional[str] = None,
        routed_by_rule: bool = False,
    ) -> None:
        """Add a message to the aggregation buffer.

        If aggregation is disabled for the pipeline, the message is sent
        directly to the query pool. Otherwise, it's buffered and will be
        merged with other messages from the same session. Each new message
        restarts the session's debounce timer; a buffer that reaches
        ``MAX_BUFFER_MESSAGES`` is flushed immediately.
        """
        enabled, delay = await self._get_aggregation_config(pipeline_uuid)

        pending_msg = PendingMessage(
            bot_uuid=bot_uuid,
            launcher_type=launcher_type,
            launcher_id=launcher_id,
            sender_id=sender_id,
            message_event=message_event,
            message_chain=message_chain,
            adapter=adapter,
            pipeline_uuid=pipeline_uuid,
            routed_by_rule=routed_by_rule,
        )

        if not enabled:
            # Aggregation disabled, send directly to query pool
            await self._dispatch(pending_msg)
            return

        session_id = self._get_session_id(bot_uuid, launcher_type, launcher_id)

        force_flush = False
        async with self.lock:
            buffer = self.buffers.get(session_id)
            if buffer is None:
                buffer = SessionBuffer(session_id=session_id)
                self.buffers[session_id] = buffer
            elif buffer.timer_task is not None and not buffer.timer_task.done():
                # Cancel existing timer (just cancel, don't await inside lock)
                buffer.timer_task.cancel()

            buffer.messages.append(pending_msg)
            buffer.last_message_time = time.time()

            if len(buffer.messages) >= MAX_BUFFER_MESSAGES:
                # Buffer reached max capacity: flush now, no new timer
                force_flush = True
            else:
                # Start new timer
                buffer.timer_task = asyncio.create_task(self._delayed_flush(session_id, delay))

        # Flush outside the lock; _flush_buffer acquires it itself.
        if force_flush:
            await self._flush_buffer(session_id)

    async def _delayed_flush(self, session_id: str, delay: float) -> None:
        """Wait for ``delay`` seconds, then flush the session's buffer."""
        try:
            await asyncio.sleep(delay)
            await self._flush_buffer(session_id)
        except asyncio.CancelledError:
            # Timer was cancelled: a newer message restarted the debounce
            # window, or the buffer is being flushed by someone else.
            pass

    async def _flush_buffer(self, session_id: str) -> None:
        """Flush the buffer for a session, merging all messages.

        The buffer is removed under the lock; the merged message is sent to
        the query pool outside of it. Does nothing if the session has no
        buffer.
        """
        async with self.lock:
            buffer = self.buffers.pop(session_id, None)
            if buffer is None:
                return

            # A timer may still be attached to the buffer we just removed
            # (e.g. a message arrived while flush_all() was awaiting). Cancel
            # it so it cannot later flush a newer buffer of this session
            # early. Never cancel the task we are running in, otherwise the
            # dispatch below would be interrupted.
            timer = buffer.timer_task
            if timer is not None and timer is not asyncio.current_task() and not timer.done():
                timer.cancel()

        if not buffer.messages:
            return

        # _merge_messages returns a lone message unchanged.
        await self._dispatch(self._merge_messages(buffer.messages))

    def _merge_messages(self, messages: list[PendingMessage]) -> PendingMessage:
        """Merge multiple messages into one.

        The merged message uses the first message as base and combines
        all message chains with newline separators.
        The original message_event is kept unmodified to preserve
        message metadata (message_id, etc.) for reply/quote.

        Args:
            messages: Non-empty list of pending messages in arrival order.

        Returns:
            The single message itself when only one is given, otherwise a new
            PendingMessage carrying the combined message chain.
        """
        if len(messages) == 1:
            return messages[0]

        base_msg = messages[0]

        # Build merged message chain
        merged_chain = platform_message.MessageChain([])
        for i, msg in enumerate(messages):
            if i > 0:
                # Add newline separator between messages
                merged_chain.append(platform_message.Plain(text='\n'))

            # Copy all components from this message
            for component in msg.message_chain:
                merged_chain.append(component)

        # Keep message_event unmodified (preserves original message_id and
        # metadata for reply/quote), only pass merged chain separately
        return PendingMessage(
            bot_uuid=base_msg.bot_uuid,
            launcher_type=base_msg.launcher_type,
            launcher_id=base_msg.launcher_id,
            sender_id=base_msg.sender_id,
            message_event=base_msg.message_event,
            message_chain=merged_chain,
            adapter=base_msg.adapter,
            pipeline_uuid=base_msg.pipeline_uuid,
            routed_by_rule=any(msg.routed_by_rule for msg in messages),
        )

    async def flush_all(self) -> None:
        """Flush all pending buffers immediately.

        This is useful during shutdown to ensure no messages are lost.
        """
        # Snapshot session IDs and cancel all timers under lock
        async with self.lock:
            session_ids = list(self.buffers)
            for buffer in self.buffers.values():
                if buffer.timer_task is not None and not buffer.timer_task.done():
                    buffer.timer_task.cancel()

        # Flush each buffer outside the lock
        for session_id in session_ids:
            await self._flush_buffer(session_id)
|