mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 20:14:36 +00:00
* Initial plan * Add markdown rendering support to pipeline debug dialog messages with toggle button Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> * Fix code review feedback: remove conflicting styles and imports Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> * perf: styles * fix: websocket message broadcasting cross-contamination between person and group channels --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> Co-authored-by: Junyan Qin <rockchinq@gmail.com>
189 lines
7.1 KiB
Python
189 lines
7.1 KiB
Python
"""WebSocket连接管理器 - 管理多个并发WebSocket连接"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import typing
|
||
import uuid
|
||
from datetime import datetime
|
||
|
||
import pydantic
|
||
|
||
# Logger named after this module; used for connection lifecycle diagnostics.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _new_connection_id() -> str:
    """Return a fresh random UUID4 string for use as a connection id."""
    return str(uuid.uuid4())


class WebSocketConnection(pydantic.BaseModel):
    """Record describing a single WebSocket connection.

    The raw websocket object and the outbound message queue are marked
    ``exclude=True`` so they are left out of pydantic serialization.
    """

    # Required so the model can hold non-pydantic types (websocket, asyncio.Queue).
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    # Unique identifier of this connection (random UUID4 string).
    connection_id: str = pydantic.Field(default_factory=_new_connection_id)
    # UUID of the pipeline this connection is attached to.
    pipeline_uuid: str
    # Session type of the connection: 'person' or 'group'.
    session_type: str
    # Underlying websocket object (e.g. quart.websocket); excluded from dumps.
    websocket: typing.Any = pydantic.Field(exclude=True)
    # Naive local timestamp taken when the record was created.
    created_at: datetime = pydantic.Field(default_factory=datetime.now)
    # Naive local timestamp of the most recent activity on this connection.
    last_active: datetime = pydantic.Field(default_factory=datetime.now)
    # Outbound message queue for this connection; excluded from dumps.
    send_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)
    # Whether the connection is still considered live.
    is_active: bool = pydantic.Field(default=True)
    # Free-form extra information supplied by the caller.
    metadata: dict = pydantic.Field(default_factory=dict)
|
||
|
||
|
||
class WebSocketConnectionManager:
    """Registry of WebSocket connections supporting many concurrent clients.

    Connections are indexed three ways (by connection id, by pipeline UUID and
    by session type) so messages can be routed to a subset of clients.
    Mutations (``add_connection`` / ``remove_connection``) are serialized with
    an ``asyncio.Lock``. The read helpers contain no ``await`` and therefore
    run atomically within a single event-loop step.
    """

    def __init__(self):
        # All registered connections: {connection_id: connection}
        self.connections: dict[str, WebSocketConnection] = {}
        # Pipeline index: {pipeline_uuid: {connection_id, ...}}
        self.pipeline_connections: dict[str, set[str]] = {}
        # Session-type index: {session_type: {connection_id, ...}}
        self.session_connections: dict[str, set[str]] = {}
        # asyncio lock serializing add/remove between coroutines. This is NOT
        # a thread lock; the manager must be used from a single event loop.
        self._lock = asyncio.Lock()

    async def add_connection(
        self,
        websocket: typing.Any,
        pipeline_uuid: str,
        session_type: str,
        metadata: typing.Optional[dict] = None,
    ) -> 'WebSocketConnection':
        """Register a new WebSocket connection and add it to all indexes.

        Args:
            websocket: The underlying websocket object (e.g. ``quart.websocket``).
            pipeline_uuid: UUID of the pipeline the connection belongs to.
            session_type: Session type, e.g. ``'person'`` or ``'group'``.
            metadata: Optional extra data stored on the connection; an empty
                dict is used when omitted or falsy.

        Returns:
            The newly created connection record.
        """
        async with self._lock:
            connection = WebSocketConnection(
                pipeline_uuid=pipeline_uuid,
                session_type=session_type,
                websocket=websocket,
                metadata=metadata or {},
            )
            cid = connection.connection_id

            self.connections[cid] = connection
            self.pipeline_connections.setdefault(pipeline_uuid, set()).add(cid)
            self.session_connections.setdefault(session_type, set()).add(cid)

            logger.debug(
                f'WebSocket connection established: {connection.connection_id} '
                f'(pipeline={pipeline_uuid}, session_type={session_type})'
            )

        return connection

    @staticmethod
    def _discard_from_index(index: dict[str, set[str]], key: str, connection_id: str) -> None:
        """Remove ``connection_id`` from ``index[key]``; drop the key once its set is empty."""
        ids = index.get(key)
        if ids is None:
            return
        ids.discard(connection_id)
        if not ids:
            del index[key]

    async def remove_connection(self, connection_id: str):
        """Unregister a connection and mark it inactive.

        Unknown ids are ignored, so calling this twice is harmless.

        Args:
            connection_id: Id of the connection to remove.
        """
        async with self._lock:
            connection = self.connections.pop(connection_id, None)
            if connection is None:
                return

            connection.is_active = False
            self._discard_from_index(self.pipeline_connections, connection.pipeline_uuid, connection_id)
            self._discard_from_index(self.session_connections, connection.session_type, connection_id)

            logger.debug(f'WebSocket connection disconnected: {connection_id}')

    async def get_connection(self, connection_id: str) -> typing.Optional['WebSocketConnection']:
        """Return the connection with the given id, or ``None`` if not registered."""
        return self.connections.get(connection_id)

    async def get_connections_by_pipeline(self, pipeline_uuid: str) -> list['WebSocketConnection']:
        """Return all registered connections attached to ``pipeline_uuid``."""
        connection_ids = self.pipeline_connections.get(pipeline_uuid, set())
        return [self.connections[cid] for cid in connection_ids if cid in self.connections]

    async def get_connections_by_session_type(self, session_type: str) -> list['WebSocketConnection']:
        """Return all registered connections with the given session type (any pipeline)."""
        connection_ids = self.session_connections.get(session_type, set())
        return [self.connections[cid] for cid in connection_ids if cid in self.connections]

    async def broadcast_to_pipeline(
        self,
        pipeline_uuid: str,
        message: dict,
        session_type: typing.Optional[str] = None,
    ):
        """Enqueue ``message`` on every connection of a pipeline.

        Args:
            pipeline_uuid: UUID of the target pipeline.
            message: Message to enqueue on each matching connection.
            session_type: Optional filter; when given, only connections whose
                ``session_type`` equals it receive the message. This keeps
                'person' and 'group' debug channels from seeing each other's
                traffic.
        """
        connections = await self.get_connections_by_pipeline(pipeline_uuid)

        if session_type is not None:
            connections = [conn for conn in connections if conn.session_type == session_type]

        if not connections:
            return

        results = await asyncio.gather(
            *(self.send_to_connection(conn.connection_id, message) for conn in connections),
            return_exceptions=True,
        )
        # return_exceptions=True hands failures back as values; surface them
        # instead of silently discarding them.
        for conn, result in zip(connections, results):
            if isinstance(result, BaseException):
                logger.error(f'Failed to broadcast message to connection {conn.connection_id}: {result!r}')

    async def send_to_connection(self, connection_id: str, message: dict):
        """Enqueue ``message`` on a single connection's send queue.

        Missing or inactive connections are logged and skipped. If enqueueing
        raises, the error is logged and the connection is removed.

        Args:
            connection_id: Id of the target connection.
            message: Message to enqueue.
        """
        connection = await self.get_connection(connection_id)
        if not connection or not connection.is_active:
            logger.warning(f'Attempt to send message to invalid connection: {connection_id}')
            return

        try:
            await connection.send_queue.put(message)
            connection.last_active = datetime.now()
        except Exception as e:
            logger.error(f'Failed to send message to connection {connection_id}: {e}')
            await self.remove_connection(connection_id)

    async def update_activity(self, connection_id: str):
        """Refresh ``last_active`` on a connection; unknown ids are ignored."""
        connection = await self.get_connection(connection_id)
        if connection:
            connection.last_active = datetime.now()

    def get_stats(self) -> dict:
        """Return connection counts overall, per pipeline and per session type."""
        return {
            'total_connections': len(self.connections),
            'pipelines': len(self.pipeline_connections),
            'connections_by_pipeline': {k: len(v) for k, v in self.pipeline_connections.items()},
            'connections_by_session_type': {k: len(v) for k, v in self.session_connections.items()},
        }
|
||
|
||
|
||
# Module-level (global) connection manager instance.
ws_connection_manager: WebSocketConnectionManager = WebSocketConnectionManager()
|