Files
LangBot/src/langbot/pkg/platform/sources/websocket_adapter.py
2025-12-11 18:54:16 +08:00

416 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""WebSocket adapter - an IM platform integration that supports two-way communication."""
import asyncio
import logging
import typing
from datetime import datetime
import pydantic
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ...core import app
from .websocket_manager import ws_connection_manager, WebSocketConnection
logger = logging.getLogger(__name__)
class WebSocketMessage(pydantic.BaseModel):
    """Format of a single chat message exchanged over the WebSocket.

    Instances are kept in the per-pipeline history and serialized with
    ``model_dump()`` before being sent to clients.
    """

    # Numeric identifier of the message.
    id: int

    # Author of the message: 'user' or 'assistant'.
    role: str

    # Plain-text form of the message.
    content: str

    # Components of the message, one dict per component.
    message_chain: list[dict]

    # Time associated with the message, stored as a string.
    timestamp: str

    # Whether the message is complete; defaults to False.
    is_final: bool = False

    # ID of the connection that sent the message; empty when not set.
    connection_id: str = ''
class WebSocketSession:
    """WebSocket session - keeps the message history of a single session.

    The history is grouped per pipeline: ``message_lists`` maps a pipeline
    UUID to the ordered list of messages recorded for that pipeline.
    """

    id: str

    # {pipeline_uuid: [messages]}
    # Only annotated here, never assigned at class level: a class-level dict
    # would be one mutable object shared by every instance that did not
    # shadow it. Each instance gets its own dict in __init__.
    message_lists: 'dict[str, list[WebSocketMessage]]'

    def __init__(self, id: str):
        """Create an empty session.

        Args:
            id: Identifier of the session.
        """
        self.id = id
        self.message_lists = {}

    def get_message_list(self, pipeline_uuid: str) -> 'list[WebSocketMessage]':
        """Return the message list of a pipeline, creating it on first access.

        The returned list is the stored object itself, so appending to it
        updates the session history.

        Args:
            pipeline_uuid: UUID of the pipeline whose history is requested.

        Returns:
            The (possibly empty) list of messages for ``pipeline_uuid``.
        """
        return self.message_lists.setdefault(pipeline_uuid, [])
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
    """WebSocket adapter - supports two-way real-time communication.

    The adapter keeps one history session for person chats and one for group
    chats. Incoming client messages are stored, broadcast to the connections
    of the pipeline and turned into platform events; replies are stored and
    broadcast through ``ws_connection_manager``.

    NOTE(review): ``stream_enabled`` and the proxy bot's ``use_pipeline_uuid``
    are adapter-wide values that are overwritten by every incoming message, so
    messages handled concurrently for different pipelines may observe each
    other's values - confirm whether that is acceptable.
    """

    # WebSocketSession requires an ``id``, so the class itself cannot be used
    # as a zero-argument default factory; build the default explicitly.
    websocket_person_session: WebSocketSession = pydantic.Field(
        exclude=True, default_factory=lambda: WebSocketSession(id='websocketperson')
    )
    websocket_group_session: WebSocketSession = pydantic.Field(
        exclude=True, default_factory=lambda: WebSocketSession(id='websocketgroup')
    )

    # Event type -> coroutine function invoked for events of that type.
    listeners: dict[
        typing.Type[platform_events.Event],
        typing.Callable[
            [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
        ],
    ] = pydantic.Field(default_factory=dict, exclude=True)

    ap: app.Application = pydantic.Field(exclude=True)

    # Queue of messages pushed proactively by the backend (see send_message).
    outbound_message_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)

    # Whether streaming output is enabled.
    stream_enabled: bool = pydantic.Field(default=True, exclude=True)

    # Listener tasks that are still running. The event loop only keeps weak
    # references to tasks, so they are held here until they finish.
    background_tasks: set[asyncio.Task] = pydantic.Field(default_factory=set, exclude=True)

    def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
        """Initialise the adapter with fresh sessions and an empty outbound queue.

        Args:
            config: Adapter configuration, forwarded to the base class.
            logger: Event logger, forwarded to the base class.
            **kwargs: Remaining fields forwarded to the base class.
        """
        super().__init__(
            config=config,
            logger=logger,
            **kwargs,
        )
        self.websocket_person_session = WebSocketSession(id='websocketperson')
        self.websocket_group_session = WebSocketSession(id='websocketgroup')
        self.bot_account_id = 'websocketbot'
        self.outbound_message_queue = asyncio.Queue()
        self.stream_enabled = True
        self.background_tasks = set()

    def _get_session(self, session_type: str) -> WebSocketSession:
        """Return the group session for 'group' and the person session otherwise."""
        return self.websocket_group_session if session_type == 'group' else self.websocket_person_session

    @staticmethod
    def _session_type_of(message_source: platform_events.MessageEvent) -> str:
        """Return 'group' for group message events and 'person' for anything else."""
        return 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'

    @staticmethod
    def _dump_components(message: platform_message.MessageChain) -> list[dict]:
        """Return the attribute dict of every component in the message chain."""
        return [component.__dict__ for component in message]

    async def _broadcast_session_message(
        self,
        pipeline_uuid: str,
        payload_type: str,
        session_type: str,
        message: WebSocketMessage,
    ):
        """Broadcast a stored message, tagged with its session type, to the pipeline's connections."""
        await ws_connection_manager.broadcast_to_pipeline(
            pipeline_uuid,
            {
                'type': payload_type,
                'session_type': session_type,
                'data': message.model_dump(),
            },
            session_type=session_type,
        )

    async def send_message(
        self,
        target_type: str,
        target_id: str,
        message: platform_message.MessageChain,
    ) -> dict:
        """Send a message - used here to push a message proactively to the frontend.

        The payload is only queued; ``run_async`` takes it from the queue and
        broadcasts it.

        Args:
            target_type: Target type, copied into the payload.
            target_id: Target ID, copied into the payload.
            message: Message chain to send.

        Returns:
            The payload that was queued.
        """
        message_data = {
            'type': 'bot_message',
            'target_type': target_type,
            'target_id': target_id,
            'content': str(message),
            'message_chain': self._dump_components(message),
            'timestamp': datetime.now().isoformat(),
        }

        await self.outbound_message_queue.put(message_data)

        return message_data

    async def reply_message(
        self,
        message_source: platform_events.MessageEvent,
        message: platform_message.MessageChain,
        quote_origin: bool = False,
    ) -> dict:
        """Reply to a message - non-streaming.

        The reply is appended to the history of the current pipeline and
        broadcast to that pipeline's connections as a final message.

        Args:
            message_source: Event being replied to; a GroupMessage selects the
                group session, anything else the person session.
            message: Reply content.
            quote_origin: Not used by this adapter.

        Returns:
            The stored reply as a dict.
        """
        session_type = self._session_type_of(message_source)
        # The pipeline is the one recorded by the most recent incoming message.
        pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
        history = self._get_session(session_type).get_message_list(pipeline_uuid)

        message_data = WebSocketMessage(
            id=len(history) + 1,
            role='assistant',
            content=str(message),
            message_chain=self._dump_components(message),
            timestamp=datetime.now().isoformat(),
            is_final=True,
        )
        history.append(message_data)

        await self._broadcast_session_message(pipeline_uuid, 'response', session_type, message_data)

        return message_data.model_dump()

    async def reply_message_chunk(
        self,
        message_source: platform_events.MessageEvent,
        bot_message,
        message: platform_message.MessageChain,
        quote_origin: bool = False,
        is_final: bool = False,
    ) -> dict:
        """Reply with a message chunk - streaming.

        Every chunk is broadcast, but only the finished message is written to
        the history. A chunk counts as finished when ``is_final`` is set and
        ``bot_message.tool_calls`` is None.

        Args:
            message_source: Event being replied to; a GroupMessage selects the
                group session, anything else the person session.
            bot_message: Object whose ``tool_calls`` attribute is read when
                ``is_final`` is true.
            message: Content of the reply so far.
            quote_origin: Not used by this adapter.
            is_final: Whether this is the last chunk of the stream.

        Returns:
            The broadcast message as a dict.
        """
        session_type = self._session_type_of(message_source)
        pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
        history = self._get_session(session_type).get_message_list(pipeline_uuid)

        # ``bot_message`` is only inspected for the last chunk.
        finished = is_final and bot_message.tool_calls is None

        # A new message starts when the history is empty or its last entry is
        # already final; otherwise the unfinished last entry is continued.
        starts_new_message = not history or history[-1].is_final
        if starts_new_message:
            msg_id = len(history) + 1
            timestamp = datetime.now().isoformat()
        else:
            msg_id = history[-1].id
            timestamp = history[-1].timestamp  # keep the original timestamp

        message_data = WebSocketMessage(
            id=msg_id,
            role='assistant',
            content=str(message),
            message_chain=self._dump_components(message),
            timestamp=timestamp,
            is_final=finished,
        )

        # Only a finished message is written to the history.
        if finished:
            if starts_new_message:
                history.append(message_data)
            else:
                history[-1] = message_data

        await self._broadcast_session_message(pipeline_uuid, 'response', session_type, message_data)

        return message_data.model_dump()

    async def is_stream_output_supported(self) -> bool:
        """Return whether streaming output is enabled, according to ``stream_enabled``."""
        return self.stream_enabled

    def register_listener(
        self,
        event_type: typing.Type[platform_events.Event],
        func: typing.Callable[
            [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
        ],
    ):
        """Register an event listener, replacing any listener already set for the type.

        Args:
            event_type: Event class the listener handles.
            func: Coroutine function called with the event and this adapter.
        """
        self.listeners[event_type] = func

    def unregister_listener(
        self,
        event_type: typing.Type[platform_events.Event],
        func: typing.Callable[
            [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
        ],
    ):
        """Unregister the event listener of an event type.

        Does nothing when no listener is registered for ``event_type``.

        Args:
            event_type: Event class whose listener is removed.
            func: Not used; the listener is looked up by ``event_type`` only.
        """
        self.listeners.pop(event_type, None)

    async def is_muted(self, group_id: int) -> bool:
        """Return False; this adapter never reports a group as muted."""
        return False

    async def run_async(self):
        """Run the adapter: forward proactively pushed messages until cancelled.

        Waits on the outbound queue instead of polling it, so an idle adapter
        does no periodic work and a queued message is forwarded as soon as it
        arrives. Cancelling the task interrupts the wait.
        """
        while True:
            message = await self.outbound_message_queue.get()
            # NOTE(review): target_id is passed where the other calls pass a
            # pipeline UUID - confirm callers of send_message supply one.
            target_id = message.get('target_id', '')
            await ws_connection_manager.broadcast_to_pipeline(target_id, message)

    async def kill(self):
        """Stop the adapter. Nothing needs to be released here."""
        pass

    async def _process_image_components(self, message_chain_obj: list):
        """Convert the stored file of every image component to a base64 data URI.

        For each dict component of type 'Image' with a non-empty 'path', the
        file is loaded from storage, written to the component's 'base64' key
        as a data URI, deleted from storage, and 'path' is cleared. Failures
        are logged and the component is left as it was at the point of failure.

        Args:
            message_chain_obj: List of message chain components; modified in place.
        """
        import base64

        # Extension -> MIME type; anything else falls back to image/png.
        mime_types = (
            (('.jpg', '.jpeg'), 'image/jpeg'),
            ('.png', 'image/png'),
            ('.gif', 'image/gif'),
            ('.webp', 'image/webp'),
        )

        storage_mgr = self.ap.storage_mgr

        for component in message_chain_obj:
            # The list comes from the client message; skip anything that is not a dict.
            if not isinstance(component, dict):
                continue
            if component.get('type') != 'Image' or not component.get('path'):
                continue

            try:
                # NOTE(review): 'path' is taken from the incoming message and is
                # used to load and then delete a stored file - confirm it is
                # restricted to files the sender is allowed to access.
                file_content = await storage_mgr.storage_provider.load(component['path'])
                base64_str = base64.b64encode(file_content).decode('utf-8')

                file_key = component['path'].lower()
                mime_type = next(
                    (mime for suffixes, mime in mime_types if file_key.endswith(suffixes)),
                    'image/png',
                )

                component['base64'] = f'data:{mime_type};base64,{base64_str}'
                # The stored file is removed and the path cleared; from here on
                # the image is carried by the base64 field only.
                await storage_mgr.storage_provider.delete(component['path'])
                component['path'] = ''
            except Exception as e:
                await self.logger.error(f'加载图片文件失败 {component["path"]}: {e}')

    async def handle_websocket_message(
        self,
        connection: WebSocketConnection,
        message_data: dict,
    ):
        """Handle a message received from a WebSocket connection.

        The message is stored in the history, broadcast to the pipeline's
        connections and handed to the registered listener. The listener runs
        as a background task; this method does not wait for a response -
        responses are sent by ``reply_message`` / ``reply_message_chunk``.

        Args:
            connection: Connection the message was received on.
            message_data: Message data, containing:
                - message: the message chain
                - stream: whether to enable streaming output (optional, default True)
        """
        pipeline_uuid = connection.pipeline_uuid
        session_type = connection.session_type

        self.stream_enabled = message_data.get('stream', True)

        message_chain_obj = message_data.get('message', [])
        await self._process_image_components(message_chain_obj)
        message_chain = platform_message.MessageChain.model_validate(message_chain_obj)

        history = self._get_session(session_type).get_message_list(pipeline_uuid)
        message_id = len(history) + 1

        user_message = WebSocketMessage(
            id=message_id,
            role='user',
            content=str(message_chain),
            message_chain=message_chain_obj,
            timestamp=datetime.now().isoformat(),
            connection_id=connection.connection_id,
            is_final=True,  # user messages are always complete, never streamed
        )
        history.append(user_message)

        # Broadcast the user message to all connections, including the sender.
        await self._broadcast_session_message(pipeline_uuid, 'user_message', session_type, user_message)

        message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))

        # Same rule as _get_session: 'group' is a group chat, anything else a
        # person chat, so the reply is stored in the session used above.
        sender_id = f'websocket_{connection.connection_id}'
        if session_type == 'group':
            group = platform_entities.Group(
                id='websocketgroup', name='Group', permission=platform_entities.Permission.Member
            )
            sender = platform_entities.GroupMember(
                id=sender_id,
                member_name='User',
                group=group,
                permission=platform_entities.Permission.Member,
            )
            event = platform_events.GroupMessage(
                sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
            )
        else:
            sender = platform_entities.Friend(id=sender_id, nickname='User', remark='User')
            event = platform_events.FriendMessage(
                sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
            )

        # Record the pipeline for the reply methods, which read it back.
        self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid

        # Run the listener in the background without waiting for its result.
        # The task is referenced until it completes so it cannot be
        # garbage-collected while still running.
        if event.__class__ in self.listeners:
            task = asyncio.create_task(self.listeners[event.__class__](event, self))
            self.background_tasks.add(task)
            task.add_done_callback(self.background_tasks.discard)

    def get_websocket_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
        """Return the message history of a pipeline as a list of dicts.

        Args:
            pipeline_uuid: UUID of the pipeline.
            session_type: 'group' for the group session; anything else selects
                the person session.
        """
        history = self._get_session(session_type).get_message_list(pipeline_uuid)
        return [message.model_dump() for message in history]

    def reset_session(self, pipeline_uuid: str, session_type: str):
        """Clear the message history of a pipeline, if it has one.

        Args:
            pipeline_uuid: UUID of the pipeline.
            session_type: 'group' for the group session; anything else selects
                the person session.
        """
        message_lists = self._get_session(session_type).message_lists
        if pipeline_uuid in message_lists:
            message_lists[pipeline_uuid] = []