mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-03 20:44:36 +00:00
520 lines
20 KiB
Python
520 lines
20 KiB
Python
"""WebSocket适配器 - 支持双向通信的IM系统"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import typing
|
||
from datetime import datetime
|
||
|
||
import pydantic
|
||
|
||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||
from ...core import app
|
||
from .websocket_manager import ws_connection_manager, WebSocketConnection
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class WebSocketMessage(pydantic.BaseModel):
|
||
"""WebSocket消息格式"""
|
||
|
||
id: int
|
||
role: str # 'user' or 'assistant'
|
||
content: str
|
||
message_chain: list[dict]
|
||
timestamp: str
|
||
is_final: bool = False
|
||
connection_id: str = ''
|
||
"""发送者连接ID"""
|
||
|
||
|
||
class WebSocketSession:
|
||
"""WebSocket会话 - 管理单个会话的消息历史"""
|
||
|
||
id: str
|
||
message_lists: dict[str, list[WebSocketMessage]] = {}
|
||
"""消息列表 {pipeline_uuid: [messages]}"""
|
||
stream_message_indexes: dict[str, dict[str, int]] = {}
|
||
"""流式消息索引 {pipeline_uuid: {resp_message_id: message_index}}"""
|
||
|
||
def __init__(self, id: str):
|
||
self.id = id
|
||
self.message_lists = {}
|
||
self.stream_message_indexes = {}
|
||
|
||
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
|
||
if pipeline_uuid not in self.message_lists:
|
||
self.message_lists[pipeline_uuid] = []
|
||
return self.message_lists[pipeline_uuid]
|
||
|
||
def get_stream_message_indexes(self, pipeline_uuid: str) -> dict[str, int]:
|
||
if pipeline_uuid not in self.stream_message_indexes:
|
||
self.stream_message_indexes[pipeline_uuid] = {}
|
||
return self.stream_message_indexes[pipeline_uuid]
|
||
|
||
|
||
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||
"""WebSocket适配器 - 支持双向实时通信"""
|
||
|
||
websocket_person_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)
|
||
websocket_group_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)
|
||
|
||
listeners: dict[
|
||
typing.Type[platform_events.Event],
|
||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||
] = pydantic.Field(default_factory=dict, exclude=True)
|
||
|
||
ap: app.Application = pydantic.Field(exclude=True)
|
||
|
||
# 主动推送消息的队列
|
||
outbound_message_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)
|
||
"""后端主动推送消息的队列"""
|
||
|
||
# 流式输出开关
|
||
stream_enabled: bool = pydantic.Field(default=True, exclude=True)
|
||
"""是否启用流式输出"""
|
||
|
||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||
super().__init__(
|
||
config=config,
|
||
logger=logger,
|
||
**kwargs,
|
||
)
|
||
|
||
self.websocket_person_session = WebSocketSession(id='websocketperson')
|
||
self.websocket_group_session = WebSocketSession(id='websocketgroup')
|
||
|
||
self.bot_account_id = 'websocketbot'
|
||
self.outbound_message_queue = asyncio.Queue()
|
||
self.stream_enabled = True
|
||
|
||
async def send_message(
|
||
self,
|
||
target_type: str,
|
||
target_id: str,
|
||
message: platform_message.MessageChain,
|
||
) -> dict:
|
||
"""发送消息 - 这里用于主动推送消息到前端
|
||
|
||
对于 WebSocket 适配器,我们需要将消息广播到正确的 pipeline 连接。
|
||
target_id 可能是 launcher_id(如 websocket_xxx)或 pipeline_uuid。
|
||
我们需要尝试两种方式来确保消息能够送达。
|
||
"""
|
||
# 获取当前的 pipeline_uuid
|
||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||
session_type = 'group' if target_type == 'group' else 'person'
|
||
|
||
# 选择会话
|
||
session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session
|
||
|
||
# 生成唯一消息ID
|
||
msg_id = len(session.get_message_list(pipeline_uuid)) + 1
|
||
|
||
message_data = WebSocketMessage(
|
||
id=msg_id,
|
||
role='assistant',
|
||
content=str(message),
|
||
message_chain=[component.__dict__ for component in message],
|
||
timestamp=datetime.now().isoformat(),
|
||
is_final=True,
|
||
)
|
||
|
||
# 保存到历史记录
|
||
session.get_message_list(pipeline_uuid).append(message_data)
|
||
|
||
# 直接广播到当前pipeline的连接
|
||
await ws_connection_manager.broadcast_to_pipeline(
|
||
pipeline_uuid,
|
||
{
|
||
'type': 'response',
|
||
'session_type': session_type,
|
||
'data': message_data.model_dump(),
|
||
},
|
||
session_type=session_type,
|
||
)
|
||
|
||
return message_data.model_dump()
|
||
|
||
async def reply_message(
|
||
self,
|
||
message_source: platform_events.MessageEvent,
|
||
message: platform_message.MessageChain,
|
||
quote_origin: bool = False,
|
||
) -> dict:
|
||
"""回复消息 - 非流式"""
|
||
# 获取会话和pipeline信息
|
||
session = (
|
||
self.websocket_group_session
|
||
if isinstance(message_source, platform_events.GroupMessage)
|
||
else self.websocket_person_session
|
||
)
|
||
|
||
# 从message_source获取pipeline_uuid和connection_id
|
||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||
|
||
# 生成新的消息ID
|
||
msg_id = len(session.get_message_list(pipeline_uuid)) + 1
|
||
|
||
message_data = WebSocketMessage(
|
||
id=msg_id,
|
||
role='assistant',
|
||
content=str(message),
|
||
message_chain=[component.__dict__ for component in message],
|
||
timestamp=datetime.now().isoformat(),
|
||
is_final=True,
|
||
)
|
||
|
||
# 保存到历史记录
|
||
session.get_message_list(pipeline_uuid).append(message_data)
|
||
|
||
# 直接广播到所有该pipeline的连接,包含session_type信息
|
||
await ws_connection_manager.broadcast_to_pipeline(
|
||
pipeline_uuid,
|
||
{
|
||
'type': 'response',
|
||
'session_type': session_type,
|
||
'data': message_data.model_dump(),
|
||
},
|
||
session_type=session_type,
|
||
)
|
||
|
||
return message_data.model_dump()
|
||
|
||
async def reply_message_chunk(
|
||
self,
|
||
message_source: platform_events.MessageEvent,
|
||
bot_message,
|
||
message: platform_message.MessageChain,
|
||
quote_origin: bool = False,
|
||
is_final: bool = False,
|
||
) -> dict:
|
||
"""回复消息块 - 流式"""
|
||
# 获取会话和pipeline信息
|
||
session = (
|
||
self.websocket_group_session
|
||
if isinstance(message_source, platform_events.GroupMessage)
|
||
else self.websocket_person_session
|
||
)
|
||
|
||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||
session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||
message_list = session.get_message_list(pipeline_uuid)
|
||
stream_message_indexes = session.get_stream_message_indexes(pipeline_uuid)
|
||
|
||
# Streaming messages in LangBot have a stable resp_message_id during the same assistant reply.
|
||
# Use it as the primary key to avoid overwriting an old card from a previous reply.
|
||
resp_message_id = str(getattr(bot_message, 'resp_message_id', '') or '')
|
||
existing_index = stream_message_indexes.get(resp_message_id) if resp_message_id else None
|
||
|
||
message_is_final = is_final and bot_message.tool_calls is None
|
||
|
||
if existing_index is None or existing_index >= len(message_list):
|
||
# 创建新消息
|
||
msg_id = len(message_list) + 1
|
||
message_data = WebSocketMessage(
|
||
id=msg_id,
|
||
role='assistant',
|
||
content=str(message),
|
||
message_chain=[component.__dict__ for component in message],
|
||
timestamp=datetime.now().isoformat(),
|
||
is_final=message_is_final,
|
||
)
|
||
|
||
# 立即添加到历史记录(即使is_final=False),以便后续块可以更新它
|
||
message_list.append(message_data)
|
||
if resp_message_id:
|
||
stream_message_indexes[resp_message_id] = len(message_list) - 1
|
||
else:
|
||
# 更新同一条流式消息
|
||
old_message = message_list[existing_index]
|
||
msg_id = old_message.id
|
||
message_data = WebSocketMessage(
|
||
id=msg_id,
|
||
role='assistant',
|
||
content=str(message),
|
||
message_chain=[component.__dict__ for component in message],
|
||
timestamp=old_message.timestamp, # 保持原始时间戳
|
||
is_final=message_is_final,
|
||
)
|
||
|
||
# 更新历史记录中的对应消息
|
||
message_list[existing_index] = message_data
|
||
|
||
if message_is_final and resp_message_id:
|
||
stream_message_indexes.pop(resp_message_id, None)
|
||
|
||
# 直接广播到所有该pipeline的连接,包含session_type信息
|
||
await ws_connection_manager.broadcast_to_pipeline(
|
||
pipeline_uuid,
|
||
{
|
||
'type': 'response',
|
||
'session_type': session_type,
|
||
'data': message_data.model_dump(),
|
||
},
|
||
session_type=session_type,
|
||
)
|
||
|
||
return message_data.model_dump()
|
||
|
||
async def is_stream_output_supported(self) -> bool:
|
||
"""根据stream_enabled标志返回是否支持流式输出"""
|
||
return self.stream_enabled
|
||
|
||
def register_listener(
|
||
self,
|
||
event_type: typing.Type[platform_events.Event],
|
||
func: typing.Callable[
|
||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
|
||
],
|
||
):
|
||
"""注册事件监听器"""
|
||
self.listeners[event_type] = func
|
||
|
||
def unregister_listener(
|
||
self,
|
||
event_type: typing.Type[platform_events.Event],
|
||
func: typing.Callable[
|
||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
|
||
],
|
||
):
|
||
"""取消注册事件监听器"""
|
||
del self.listeners[event_type]
|
||
|
||
async def is_muted(self, group_id: int) -> bool:
|
||
return False
|
||
|
||
async def run_async(self):
|
||
"""运行适配器"""
|
||
|
||
try:
|
||
while True:
|
||
# 处理主动推送消息
|
||
if not self.outbound_message_queue.empty():
|
||
try:
|
||
message = await asyncio.wait_for(self.outbound_message_queue.get(), timeout=0.1)
|
||
# 广播到所有相关连接
|
||
target_id = message.get('target_id', '')
|
||
await ws_connection_manager.broadcast_to_pipeline(target_id, message)
|
||
except asyncio.TimeoutError:
|
||
pass
|
||
|
||
await asyncio.sleep(0.1)
|
||
except asyncio.CancelledError:
|
||
raise
|
||
|
||
async def kill(self):
|
||
"""停止适配器"""
|
||
pass
|
||
|
||
async def _process_image_components(self, message_chain_obj: list):
|
||
"""
|
||
处理消息链中的图片组件,将path转换为base64
|
||
|
||
Args:
|
||
message_chain_obj: 消息链对象列表
|
||
"""
|
||
import base64
|
||
|
||
storage_mgr = self.ap.storage_mgr
|
||
|
||
for component in message_chain_obj:
|
||
if component.get('type') == 'Image' and component.get('path'):
|
||
try:
|
||
# 从storage读取文件
|
||
file_content = await storage_mgr.storage_provider.load(component['path'])
|
||
|
||
# 转换为base64
|
||
base64_str = base64.b64encode(file_content).decode('utf-8')
|
||
|
||
# 添加data URI前缀(根据文件扩展名判断MIME类型)
|
||
file_key = component['path']
|
||
if file_key.lower().endswith(('.jpg', '.jpeg')):
|
||
mime_type = 'image/jpeg'
|
||
elif file_key.lower().endswith('.png'):
|
||
mime_type = 'image/png'
|
||
elif file_key.lower().endswith('.gif'):
|
||
mime_type = 'image/gif'
|
||
elif file_key.lower().endswith('.webp'):
|
||
mime_type = 'image/webp'
|
||
else:
|
||
mime_type = 'image/png' # 默认
|
||
|
||
component['base64'] = f'data:{mime_type};base64,{base64_str}'
|
||
await storage_mgr.storage_provider.delete(component['path'])
|
||
component['path'] = ''
|
||
# 保留path字段用于后端处理,前端使用base64显示
|
||
except Exception as e:
|
||
await self.logger.error(f'加载图片文件失败 {component["path"]}: {e}')
|
||
|
||
async def handle_websocket_message(
|
||
self,
|
||
connection: WebSocketConnection,
|
||
message_data: dict,
|
||
):
|
||
"""
|
||
处理从WebSocket接收的消息
|
||
|
||
这个方法只负责接收消息、保存到历史记录、并触发事件处理
|
||
不等待任何响应,响应消息会通过reply_message/reply_message_chunk直接发送
|
||
|
||
Args:
|
||
connection: WebSocket连接对象
|
||
message_data: 消息数据,包含:
|
||
- message: 消息链
|
||
- stream: 是否启用流式输出 (可选,默认True)
|
||
"""
|
||
pipeline_uuid = connection.pipeline_uuid
|
||
session_type = connection.session_type
|
||
is_workflow = bool(connection.metadata.get('is_workflow'))
|
||
|
||
# 获取stream参数,默认为True
|
||
self.stream_enabled = message_data.get('stream', True)
|
||
|
||
# 选择会话
|
||
use_session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session
|
||
|
||
# 解析消息链
|
||
message_chain_obj = message_data.get('message', [])
|
||
|
||
# 处理图片组件:将path转换为base64
|
||
await self._process_image_components(message_chain_obj)
|
||
|
||
message_chain = platform_message.MessageChain.model_validate(message_chain_obj)
|
||
|
||
# 生成消息ID
|
||
message_id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||
|
||
# 保存用户消息
|
||
user_message = WebSocketMessage(
|
||
id=message_id,
|
||
role='user',
|
||
content=str(message_chain),
|
||
message_chain=message_chain_obj,
|
||
timestamp=datetime.now().isoformat(),
|
||
connection_id=connection.connection_id,
|
||
is_final=True, # 用户消息始终是完整的,非流式
|
||
)
|
||
use_session.get_message_list(pipeline_uuid).append(user_message)
|
||
|
||
# 广播用户消息到所有连接(包括发送者),包含session_type信息
|
||
await ws_connection_manager.broadcast_to_pipeline(
|
||
pipeline_uuid,
|
||
{
|
||
'type': 'user_message',
|
||
'session_type': session_type,
|
||
'data': user_message.model_dump(),
|
||
},
|
||
session_type=session_type,
|
||
)
|
||
|
||
if is_workflow:
|
||
# 设置 pipeline_uuid,以便工作流节点发送消息时能正确广播
|
||
self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
|
||
|
||
message_content = str(message_chain)
|
||
message_context = {
|
||
'message_id': str(message_id),
|
||
'message_content': message_content,
|
||
'sender_id': f'websocket_{connection.connection_id}',
|
||
'sender_name': 'User',
|
||
'platform': 'websocket',
|
||
'conversation_id': connection.connection_id,
|
||
'is_group': session_type == 'group',
|
||
'group_id': 'websocketgroup' if session_type == 'group' else None,
|
||
'mentions': [],
|
||
'reply_to': None,
|
||
'raw_message': {
|
||
'message': message_chain_obj,
|
||
'connection_id': connection.connection_id,
|
||
'session_type': session_type,
|
||
},
|
||
}
|
||
|
||
trigger_data = {
|
||
'message': message_content,
|
||
'message_chain': message_chain_obj,
|
||
'session_type': session_type,
|
||
'connection_id': connection.connection_id,
|
||
'message_context': message_context,
|
||
}
|
||
|
||
try:
|
||
from ...api.http.service.workflow import WorkflowExecutionFailedError
|
||
|
||
execution_id = await self.ap.workflow_service.execute_workflow(
|
||
pipeline_uuid,
|
||
trigger_type='message',
|
||
trigger_data=trigger_data,
|
||
session_id=f'{session_type}_{connection.connection_id}',
|
||
user_id=message_context['sender_id'],
|
||
bot_id=self.ap.platform_mgr.websocket_proxy_bot.bot_entity.uuid,
|
||
)
|
||
await connection.send_queue.put(
|
||
{
|
||
'type': 'broadcast',
|
||
'message': f'Workflow execution started: {execution_id}',
|
||
}
|
||
)
|
||
except WorkflowExecutionFailedError as e:
|
||
await connection.send_queue.put({'type': 'error', 'message': e.message})
|
||
except Exception as e:
|
||
logger.error(f'Workflow websocket execution error: {e}', exc_info=True)
|
||
await connection.send_queue.put({'type': 'error', 'message': str(e)})
|
||
return
|
||
|
||
# 添加消息源
|
||
message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))
|
||
|
||
# 创建事件
|
||
if session_type == 'person':
|
||
sender = platform_entities.Friend(
|
||
id=f'websocket_{connection.connection_id}', nickname='User', remark='User'
|
||
)
|
||
event = platform_events.FriendMessage(
|
||
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
|
||
)
|
||
else:
|
||
group = platform_entities.Group(
|
||
id='websocketgroup', name='Group', permission=platform_entities.Permission.Member
|
||
)
|
||
sender = platform_entities.GroupMember(
|
||
id=f'websocket_{connection.connection_id}',
|
||
member_name='User',
|
||
group=group,
|
||
permission=platform_entities.Permission.Member,
|
||
)
|
||
event = platform_events.GroupMessage(
|
||
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
|
||
)
|
||
|
||
# 设置流水线UUID
|
||
self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
|
||
|
||
# 异步触发事件处理(不等待结果)
|
||
if event.__class__ in self.listeners:
|
||
asyncio.create_task(self.listeners[event.__class__](event, self))
|
||
|
||
def get_websocket_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
|
||
"""获取消息历史"""
|
||
if session_type == 'person':
|
||
return [message.model_dump() for message in self.websocket_person_session.get_message_list(pipeline_uuid)]
|
||
else:
|
||
return [message.model_dump() for message in self.websocket_group_session.get_message_list(pipeline_uuid)]
|
||
|
||
def reset_session(self, pipeline_uuid: str, session_type: str):
|
||
"""重置会话"""
|
||
if session_type == 'person':
|
||
if pipeline_uuid in self.websocket_person_session.message_lists:
|
||
self.websocket_person_session.message_lists[pipeline_uuid] = []
|
||
if pipeline_uuid in self.websocket_person_session.stream_message_indexes:
|
||
self.websocket_person_session.stream_message_indexes[pipeline_uuid] = {}
|
||
else:
|
||
if pipeline_uuid in self.websocket_group_session.message_lists:
|
||
self.websocket_group_session.message_lists[pipeline_uuid] = []
|
||
if pipeline_uuid in self.websocket_group_session.stream_message_indexes:
|
||
self.websocket_group_session.stream_message_indexes[pipeline_uuid] = {}
|