mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
416 lines
16 KiB
Python
416 lines
16 KiB
Python
"""WebSocket适配器 - 支持双向通信的IM系统"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import typing
|
||
from datetime import datetime
|
||
|
||
import pydantic
|
||
|
||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||
from ...core import app
|
||
from .websocket_manager import ws_connection_manager, WebSocketConnection
|
||
|
||
# Module-level stdlib logger. Within this file the adapter reports errors through its
# injected event logger (self.logger) rather than through this object.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class WebSocketMessage(pydantic.BaseModel):
    """One chat message exchanged over the WebSocket channel.

    Instances are kept in the per-pipeline session history and are serialized
    with ``model_dump()`` before being broadcast to connected clients.
    """

    # Sequence number within one session/pipeline history (history length + 1 at creation).
    id: int
    # Either 'user' or 'assistant'.
    role: str
    # Plain-text rendering of the message chain.
    content: str
    # Serialized message components.
    message_chain: list[dict]
    # Creation time as produced by datetime.isoformat().
    timestamp: str
    # False while a streamed assistant reply is still in progress.
    is_final: bool = pydantic.Field(default=False)
    # Connection ID of the sender; left empty for assistant replies.
    connection_id: str = pydantic.Field(default='')
|
||
|
||
|
||
class WebSocketSession:
    """In-memory message history for one WebSocket conversation scope.

    Histories are stored separately per pipeline, keyed by pipeline UUID.
    """

    id: str
    # {pipeline_uuid: [messages]}. Declared without a class-level value so that no
    # mutable dict is ever shared between instances; each instance gets its own in __init__.
    message_lists: dict[str, list[WebSocketMessage]]

    def __init__(self, id: str = ''):
        """Create an empty session.

        Args:
            id: Session identifier. It defaults to '' so the class can also serve as a
                zero-argument factory. The adapter declares its session fields with
                ``default_factory=WebSocketSession``.
        """
        self.id = id
        self.message_lists = {}

    def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
        """Return the mutable history list for *pipeline_uuid*, creating an empty one if missing."""
        return self.message_lists.setdefault(pipeline_uuid, [])
|
||
|
||
|
||
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
    """Message-platform adapter that exchanges messages with clients over WebSocket.

    Incoming client messages are:
      * recorded in an in-memory per-pipeline history;
      * broadcast to every connection of that pipeline;
      * dispatched to the registered event listeners.

    Bot replies, streamed or not, are broadcast through ``ws_connection_manager`` as
    they are produced.
    """

    websocket_person_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)
    websocket_group_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)

    # Event type -> async handler. The value type matches what register_listener accepts.
    listeners: dict[
        typing.Type[platform_events.Event],
        typing.Callable[
            [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
        ],
    ] = pydantic.Field(default_factory=dict, exclude=True)

    ap: app.Application = pydantic.Field(exclude=True)

    # Messages pushed proactively by the backend (filled by send_message, drained by run_async).
    outbound_message_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)

    # Whether streaming output is enabled; overwritten by each incoming client message.
    stream_enabled: bool = pydantic.Field(default=True, exclude=True)

    # Strong references to fire-and-forget listener tasks. The event loop keeps only weak
    # references to tasks, so an unreferenced task may be garbage-collected before it finishes.
    _background_tasks: typing.ClassVar[set[asyncio.Task]] = set()

    def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
        """Initialise the adapter and its person/group sessions.

        Args:
            config: Adapter configuration, forwarded to the base adapter.
            logger: Event logger used for error reporting.
            **kwargs: Extra fields for the base model (e.g. ``ap``).
        """
        super().__init__(
            config=config,
            logger=logger,
            **kwargs,
        )

        self.websocket_person_session = WebSocketSession(id='websocketperson')
        self.websocket_group_session = WebSocketSession(id='websocketgroup')

        self.bot_account_id = 'websocketbot'
        self.outbound_message_queue = asyncio.Queue()
        self.stream_enabled = True

    # ------------------------------------------------------------------ helpers

    def _resolve_reply_target(
        self, message_source: platform_events.MessageEvent
    ) -> tuple[WebSocketSession, str, str]:
        """Return ``(session, pipeline_uuid, session_type)`` for a reply to *message_source*.

        The pipeline UUID is read from the shared proxy-bot entity, which
        handle_websocket_message sets before dispatching an event.
        """
        is_group = isinstance(message_source, platform_events.GroupMessage)
        session = self.websocket_group_session if is_group else self.websocket_person_session
        pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
        session_type = 'group' if is_group else 'person'
        return session, pipeline_uuid, session_type

    @staticmethod
    def _serialize_chain(message: platform_message.MessageChain) -> list[dict]:
        """Convert a message chain into the list-of-dicts form stored in WebSocketMessage."""
        return [component.__dict__ for component in message]

    async def _broadcast_response(self, pipeline_uuid: str, session_type: str, message_data: WebSocketMessage):
        """Broadcast an assistant message to all connections of *pipeline_uuid* for *session_type*."""
        await ws_connection_manager.broadcast_to_pipeline(
            pipeline_uuid,
            {
                'type': 'response',
                'session_type': session_type,
                'data': message_data.model_dump(),
            },
            session_type=session_type,
        )

    @staticmethod
    def _guess_image_mime_type(file_key: str) -> str:
        """Map a storage key's extension to an image MIME type, defaulting to image/png."""
        lowered = file_key.lower()
        for suffixes, mime_type in (
            (('.jpg', '.jpeg'), 'image/jpeg'),
            (('.png',), 'image/png'),
            (('.gif',), 'image/gif'),
            (('.webp',), 'image/webp'),
        ):
            if lowered.endswith(suffixes):
                return mime_type
        return 'image/png'

    # ------------------------------------------------------------- outbound API

    async def send_message(
        self,
        target_type: str,
        target_id: str,
        message: platform_message.MessageChain,
    ) -> dict:
        """Queue a proactive (not in-reply) message for delivery to clients.

        The message is not sent here. run_async later broadcasts it, using
        *target_id* as the pipeline UUID.

        Returns:
            The queued message payload.
        """
        message_data = {
            'type': 'bot_message',
            'target_type': target_type,
            'target_id': target_id,
            'content': str(message),
            'message_chain': self._serialize_chain(message),
            'timestamp': datetime.now().isoformat(),
        }

        await self.outbound_message_queue.put(message_data)

        return message_data

    async def reply_message(
        self,
        message_source: platform_events.MessageEvent,
        message: platform_message.MessageChain,
        quote_origin: bool = False,
    ) -> dict:
        """Send a complete, non-streamed reply.

        The reply is appended to the session history and broadcast to every
        connection of the current pipeline. *quote_origin* is accepted for
        interface compatibility and ignored.

        Returns:
            The serialized reply message.
        """
        session, pipeline_uuid, session_type = self._resolve_reply_target(message_source)
        message_list = session.get_message_list(pipeline_uuid)

        message_data = WebSocketMessage(
            id=len(message_list) + 1,
            role='assistant',
            content=str(message),
            message_chain=self._serialize_chain(message),
            timestamp=datetime.now().isoformat(),
            is_final=True,
        )

        message_list.append(message_data)
        await self._broadcast_response(pipeline_uuid, session_type, message_data)

        return message_data.model_dump()

    async def reply_message_chunk(
        self,
        message_source: platform_events.MessageEvent,
        bot_message,
        message: platform_message.MessageChain,
        quote_origin: bool = False,
        is_final: bool = False,
    ) -> dict:
        """Send one chunk of a streamed reply.

        Every chunk is broadcast. The message is stored in history only once it is
        final, meaning *is_final* is set and the bot message requested no tool calls.
        *quote_origin* is accepted for interface compatibility and ignored.

        Returns:
            The serialized chunk message.
        """
        session, pipeline_uuid, session_type = self._resolve_reply_target(message_source)
        message_list = session.get_message_list(pipeline_uuid)

        # A reply that requested tool calls is not finished yet, even on its last chunk.
        finished = is_final and bot_message.tool_calls is None

        if not message_list or message_list[-1].is_final:
            # Start a new assistant message. Its id is len + 1, but it is only appended
            # to history once finished, so every intermediate chunk carries that same id.
            message_data = WebSocketMessage(
                id=len(message_list) + 1,
                role='assistant',
                content=str(message),
                message_chain=self._serialize_chain(message),
                timestamp=datetime.now().isoformat(),
                is_final=finished,
            )
            if finished:
                message_list.append(message_data)
        else:
            # Update the trailing incomplete message, keeping its id and original timestamp.
            # NOTE(review): within this file only final messages are ever appended to the
            # history, so this branch is reachable only if other code inserts a non-final one.
            last_message = message_list[-1]
            message_data = WebSocketMessage(
                id=last_message.id,
                role='assistant',
                content=str(message),
                message_chain=self._serialize_chain(message),
                timestamp=last_message.timestamp,
                is_final=finished,
            )
            if finished:
                message_list[-1] = message_data

        await self._broadcast_response(pipeline_uuid, session_type, message_data)

        return message_data.model_dump()

    async def is_stream_output_supported(self) -> bool:
        """Report streaming support according to the ``stream_enabled`` flag."""
        return self.stream_enabled

    # ---------------------------------------------------------------- listeners

    def register_listener(
        self,
        event_type: typing.Type[platform_events.Event],
        func: typing.Callable[
            [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
        ],
    ):
        """Register *func* as the handler for *event_type*, replacing any previous handler."""
        self.listeners[event_type] = func

    def unregister_listener(
        self,
        event_type: typing.Type[platform_events.Event],
        func: typing.Callable[
            [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
        ],
    ):
        """Remove the handler for *event_type*. Unregistering an unknown type is a no-op."""
        self.listeners.pop(event_type, None)

    async def is_muted(self, group_id: int) -> bool:
        """The WebSocket channel has no mute concept; always False."""
        return False

    # ---------------------------------------------------------------- lifecycle

    async def run_async(self):
        """Deliver queued proactive messages until cancelled.

        Blocks on the outbound queue instead of polling. Each message is broadcast
        to the pipeline named by its ``target_id``. A failed broadcast is logged,
        and delivery of later messages continues.
        """
        while True:
            message = await self.outbound_message_queue.get()
            target_id = message.get('target_id', '')
            try:
                await ws_connection_manager.broadcast_to_pipeline(target_id, message)
            except Exception as e:
                # CancelledError is a BaseException and still propagates to stop the loop.
                await self.logger.error(f'Failed to broadcast outbound message to {target_id}: {e}')

    async def kill(self):
        """Stop the adapter. Nothing to release here."""
        pass

    # ------------------------------------------------------------------ inbound

    async def _process_image_components(self, message_chain_obj: list):
        """Inline image files referenced by storage path as base64 data URIs, in place.

        For each ``Image`` component with a non-empty ``path``, the file is loaded from
        storage and stored in ``base64`` as a ``data:`` URI. The storage object is then
        deleted and ``path`` is cleared. Failures are logged and the component is left
        unchanged.

        Args:
            message_chain_obj: Raw message-chain component dicts from the client.
        """
        import base64

        storage_mgr = self.ap.storage_mgr

        for component in message_chain_obj:
            if component.get('type') != 'Image' or not component.get('path'):
                continue

            # NOTE(review): this key comes straight from the client payload, and the
            # referenced storage object is deleted below; confirm that callers restrict
            # which keys a client may reference.
            file_key = component['path']
            try:
                file_content = await storage_mgr.storage_provider.load(file_key)
                base64_str = base64.b64encode(file_content).decode('utf-8')
                component['base64'] = f'data:{self._guess_image_mime_type(file_key)};base64,{base64_str}'

                # The image is now inlined, so remove the stored file and clear the path.
                # Clients render from the base64 field.
                await storage_mgr.storage_provider.delete(file_key)
                component['path'] = ''
            except Exception as e:
                await self.logger.error(f'加载图片文件失败 {file_key}: {e}')

    async def handle_websocket_message(
        self,
        connection: WebSocketConnection,
        message_data: dict,
    ):
        """Handle one message received from a WebSocket client.

        Records the user message in history, broadcasts it to every connection of
        the pipeline (including the sender), and schedules the matching event
        listener. It does not wait for the reply; replies are delivered by
        reply_message / reply_message_chunk.

        Args:
            connection: The connection the message arrived on.
            message_data: Payload containing:
                - message: the message chain (list of component dicts)
                - stream: whether to stream the reply (optional, default True)
        """
        pipeline_uuid = connection.pipeline_uuid
        session_type = connection.session_type

        # NOTE(review): this flag is adapter-wide, so concurrent clients overwrite each
        # other's preference.
        self.stream_enabled = message_data.get('stream', True)

        use_session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session

        message_chain_obj = message_data.get('message', [])
        await self._process_image_components(message_chain_obj)
        message_chain = platform_message.MessageChain.model_validate(message_chain_obj)

        message_list = use_session.get_message_list(pipeline_uuid)
        message_id = len(message_list) + 1

        user_message = WebSocketMessage(
            id=message_id,
            role='user',
            content=str(message_chain),
            message_chain=message_chain_obj,
            timestamp=datetime.now().isoformat(),
            connection_id=connection.connection_id,
            is_final=True,  # user messages are always complete, never streamed
        )
        message_list.append(user_message)

        await ws_connection_manager.broadcast_to_pipeline(
            pipeline_uuid,
            {
                'type': 'user_message',
                'session_type': session_type,
                'data': user_message.model_dump(),
            },
            session_type=session_type,
        )

        message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))

        if session_type == 'person':
            sender = platform_entities.Friend(
                id=f'websocket_{connection.connection_id}', nickname='User', remark='User'
            )
            event = platform_events.FriendMessage(
                sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
            )
        else:
            group = platform_entities.Group(
                id='websocketgroup', name='Group', permission=platform_entities.Permission.Member
            )
            sender = platform_entities.GroupMember(
                id=f'websocket_{connection.connection_id}',
                member_name='User',
                group=group,
                permission=platform_entities.Permission.Member,
            )
            event = platform_events.GroupMessage(
                sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
            )

        # Reply methods read the pipeline back from this shared attribute.
        # NOTE(review): messages for different pipelines arriving concurrently can
        # overwrite it before an earlier reply is sent.
        self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid

        # Fire and forget, but keep a strong reference until the task completes.
        listener = self.listeners.get(type(event))
        if listener is not None:
            task = asyncio.create_task(listener(event, self))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

    # ------------------------------------------------------------------ history

    def get_websocket_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
        """Return the serialized history of *pipeline_uuid* for *session_type*.

        NOTE(review): any session_type other than 'person' selects the group session,
        whereas handle_websocket_message treats anything other than 'group' as person.
        """
        session = self.websocket_person_session if session_type == 'person' else self.websocket_group_session
        return [message.model_dump() for message in session.get_message_list(pipeline_uuid)]

    def reset_session(self, pipeline_uuid: str, session_type: str):
        """Clear the history of *pipeline_uuid* for *session_type*, if that history exists."""
        session = self.websocket_person_session if session_type == 'person' else self.websocket_group_session
        if pipeline_uuid in session.message_lists:
            session.message_lists[pipeline_uuid] = []
|