refactor: switch webchat from sse to websocket (#1808)

* refactor: switch webchat from sse to websocket

* perf: image preview dialog

* chore: remove console.log
This commit is contained in:
Junyan Qin (Chin)
2025-11-28 14:54:01 +08:00
committed by GitHub
parent 348620ac0a
commit d09b823c49
39 changed files with 2656 additions and 783 deletions

View File

@@ -28,8 +28,56 @@ class FilesRouterGroup(group.RouterGroup):
return quart.Response(image_bytes, mimetype=mime_type)
@self.route('/images', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def upload_image() -> quart.Response:
request = quart.request
# Check file size limit before reading the file
content_length = request.content_length
if content_length and content_length > group.MAX_FILE_SIZE:
return self.fail(400, 'Image size exceeds 10MB limit.')
# get file bytes from 'file'
files = await request.files
if 'file' not in files:
return self.fail(400, 'No image file provided')
file = files['file']
assert isinstance(file, quart.datastructures.FileStorage)
file_bytes = await asyncio.to_thread(file.stream.read)
# Double-check actual file size after reading
if len(file_bytes) > group.MAX_FILE_SIZE:
return self.fail(400, 'Image size exceeds 10MB limit.')
# Validate image file extension
allowed_extensions = {'jpg', 'jpeg', 'png', 'gif', 'webp'}
if '.' in file.filename:
file_name, extension = file.filename.rsplit('.', 1)
extension = extension.lower()
else:
return self.fail(400, 'Invalid image file: no file extension')
if extension not in allowed_extensions:
return self.fail(400, f'Invalid image format. Allowed formats: {", ".join(allowed_extensions)}')
# check if file name contains '/' or '\'
if '/' in file_name or '\\' in file_name:
return self.fail(400, 'File name contains invalid characters')
file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
# save file to storage
await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)
return self.success(
data={
'file_key': file_key,
}
)
@self.route('/documents', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> quart.Response:
async def upload_document() -> quart.Response:
request = quart.request
# Check file size limit before reading the file

View File

@@ -1,109 +0,0 @@
import json
import quart
from ... import group
@group.group_class('webchat', '/api/v1/pipelines/<pipeline_uuid>/chat')
class WebChatDebugRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/send', methods=['POST'])
async def send_message(pipeline_uuid: str) -> str:
"""Send a message to the pipeline for debugging"""
async def stream_generator(generator):
yield 'data: {"type": "start"}\n\n'
async for message in generator:
yield f'data: {json.dumps({"message": message})}\n\n'
yield 'data: {"type": "end"}\n\n'
try:
data = await quart.request.get_json()
session_type = data.get('session_type', 'person')
message_chain_obj = data.get('message', [])
is_stream = data.get('is_stream', False)
if not message_chain_obj:
return self.http_status(400, -1, 'message is required')
if session_type not in ['person', 'group']:
return self.http_status(400, -1, 'session_type must be person or group')
webchat_adapter = self.ap.platform_mgr.webchat_proxy_bot.adapter
if not webchat_adapter:
return self.http_status(404, -1, 'WebChat adapter not found')
if is_stream:
generator = webchat_adapter.send_webchat_message(
pipeline_uuid, session_type, message_chain_obj, is_stream
)
# 设置正确的响应头
headers = {
'Content-Type': 'text/event-stream',
'Transfer-Encoding': 'chunked',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
}
return quart.Response(stream_generator(generator), mimetype='text/event-stream', headers=headers)
else: # non-stream
result = None
async for message in webchat_adapter.send_webchat_message(
pipeline_uuid, session_type, message_chain_obj
):
result = message
if result is not None:
return self.success(
data={
'message': result,
}
)
else:
return self.http_status(400, -1, 'message is required')
except Exception as e:
return self.http_status(500, -1, f'Internal server error: {str(e)}')
@self.route('/messages/<session_type>', methods=['GET'])
async def get_messages(pipeline_uuid: str, session_type: str) -> str:
"""Get the message history of the pipeline for debugging"""
try:
if session_type not in ['person', 'group']:
return self.http_status(400, -1, 'session_type must be person or group')
webchat_adapter = self.ap.platform_mgr.webchat_proxy_bot.adapter
if not webchat_adapter:
return self.http_status(404, -1, 'WebChat adapter not found')
messages = webchat_adapter.get_webchat_messages(pipeline_uuid, session_type)
return self.success(data={'messages': messages})
except Exception as e:
return self.http_status(500, -1, f'Internal server error: {str(e)}')
@self.route('/reset/<session_type>', methods=['POST'])
async def reset_session(session_type: str) -> str:
"""Reset the debug session"""
try:
if session_type not in ['person', 'group']:
return self.http_status(400, -1, 'session_type must be person or group')
webchat_adapter = None
for bot in self.ap.platform_mgr.bots:
if hasattr(bot.adapter, '__class__') and bot.adapter.__class__.__name__ == 'WebChatAdapter':
webchat_adapter = bot.adapter
break
if not webchat_adapter:
return self.http_status(404, -1, 'WebChat adapter not found')
webchat_adapter.reset_debug_session(session_type)
return self.success(data={'message': 'Session reset successfully'})
except Exception as e:
return self.http_status(500, -1, f'Internal server error: {str(e)}')

View File

@@ -0,0 +1,243 @@
"""WebSocket聊天路由 - 支持双向实时通信"""
import asyncio
import datetime
import json
import logging
import quart
from ... import group
from ......platform.sources.websocket_manager import ws_connection_manager
logger = logging.getLogger(__name__)
@group.group_class('websocket_chat', '/api/v1/pipelines/<pipeline_uuid>/ws')
class WebSocketChatRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
# 直接使用 quart_app 注册 WebSocket 路由
@self.quart_app.websocket(self.path + '/connect')
async def websocket_connect(pipeline_uuid: str):
"""
建立WebSocket连接
URL参数:
- pipeline_uuid: 流水线UUID
- session_type: 会话类型 (person/group)
"""
try:
# 获取参数 - 在WebSocket上下文中使用 quart.websocket.args
session_type = quart.websocket.args.get('session_type', 'person')
if session_type not in ['person', 'group']:
await quart.websocket.send(
json.dumps({'type': 'error', 'message': 'session_type must be person or group'})
)
return
# 获取WebSocket适配器
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
if not websocket_adapter:
await quart.websocket.send(json.dumps({'type': 'error', 'message': 'WebSocket adapter not found'}))
return
# 注册连接
connection = await ws_connection_manager.add_connection(
websocket=quart.websocket._get_current_object(),
pipeline_uuid=pipeline_uuid,
session_type=session_type,
metadata={'user_agent': quart.websocket.headers.get('User-Agent', '')},
)
# 发送连接成功消息
await quart.websocket.send(
json.dumps(
{
'type': 'connected',
'connection_id': connection.connection_id,
'pipeline_uuid': pipeline_uuid,
'session_type': session_type,
'timestamp': connection.created_at.isoformat(),
}
)
)
logger.debug(
f'WebSocket connection established: {connection.connection_id} '
f'(pipeline={pipeline_uuid}, session_type={session_type})'
)
# 创建接收和发送任务
receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter))
send_task = asyncio.create_task(self._handle_send(connection))
# 等待任务完成
try:
await asyncio.gather(receive_task, send_task)
except Exception as e:
logger.error(f'WebSocket task execution error: {e}')
finally:
# 清理连接
await ws_connection_manager.remove_connection(connection.connection_id)
logger.debug(f'WebSocket connection cleaned: {connection.connection_id}')
except Exception as e:
logger.error(f'WebSocket connection error: {e}', exc_info=True)
try:
await quart.websocket.send(json.dumps({'type': 'error', 'message': str(e)}))
except:
pass
@self.route('/messages/<session_type>', methods=['GET'])
async def get_messages(pipeline_uuid: str, session_type: str) -> str:
"""获取消息历史"""
try:
if session_type not in ['person', 'group']:
return self.http_status(400, -1, 'session_type must be person or group')
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
if not websocket_adapter:
return self.http_status(404, -1, 'WebSocket adapter not found')
messages = websocket_adapter.get_websocket_messages(pipeline_uuid, session_type)
return self.success(data={'messages': messages})
except Exception as e:
return self.http_status(500, -1, f'Internal server error: {str(e)}')
@self.route('/reset/<session_type>', methods=['POST'])
async def reset_session(pipeline_uuid: str, session_type: str) -> str:
"""重置会话"""
try:
if session_type not in ['person', 'group']:
return self.http_status(400, -1, 'session_type must be person or group')
websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter
if not websocket_adapter:
return self.http_status(404, -1, 'WebSocket adapter not found')
websocket_adapter.reset_session(pipeline_uuid, session_type)
return self.success(data={'message': 'Session reset successfully'})
except Exception as e:
return self.http_status(500, -1, f'Internal server error: {str(e)}')
@self.route('/connections', methods=['GET'])
async def get_connections(pipeline_uuid: str) -> str:
"""获取当前连接统计"""
try:
stats = ws_connection_manager.get_stats()
connections = await ws_connection_manager.get_connections_by_pipeline(pipeline_uuid)
return self.success(
data={
'stats': stats,
'connections': [
{
'connection_id': conn.connection_id,
'session_type': conn.session_type,
'created_at': conn.created_at.isoformat(),
'last_active': conn.last_active.isoformat(),
'is_active': conn.is_active,
}
for conn in connections
],
}
)
except Exception as e:
return self.http_status(500, -1, f'Internal server error: {str(e)}')
@self.route('/broadcast', methods=['POST'])
async def broadcast_message(pipeline_uuid: str) -> str:
"""向所有连接广播消息(后端主动推送)"""
try:
data = await quart.request.get_json()
message = data.get('message')
if not message:
return self.http_status(400, -1, 'message is required')
# 广播消息
broadcast_data = {
'type': 'broadcast',
'message': message,
'timestamp': datetime.datetime.now().isoformat(),
}
await ws_connection_manager.broadcast_to_pipeline(pipeline_uuid, broadcast_data)
return self.success(data={'message': 'Broadcast sent successfully'})
except Exception as e:
return self.http_status(500, -1, f'Internal server error: {str(e)}')
async def _handle_receive(self, connection, websocket_adapter):
"""处理接收消息的任务"""
try:
while connection.is_active:
# 接收消息
message = await quart.websocket.receive()
# 更新活跃时间
await ws_connection_manager.update_activity(connection.connection_id)
try:
data = json.loads(message)
message_type = data.get('type', 'message')
if message_type == 'ping':
# 心跳响应
await connection.send_queue.put(
{'type': 'pong', 'timestamp': datetime.datetime.now().isoformat()}
)
elif message_type == 'message':
# 处理用户消息
logger.debug(f'收到消息: {data} from {connection.connection_id}')
# 处理消息不等待响应响应会通过broadcast异步发送
await websocket_adapter.handle_websocket_message(connection, data)
elif message_type == 'disconnect':
# 客户端主动断开
logger.debug(f'Client disconnected: {connection.connection_id}')
break
else:
logger.warning(f'Unknown message type: {message_type}')
except json.JSONDecodeError:
logger.error(f'Invalid JSON message: {message}')
await connection.send_queue.put({'type': 'error', 'message': 'Invalid JSON format'})
except Exception as e:
logger.error(f'Receive message error: {e}', exc_info=True)
finally:
connection.is_active = False
async def _handle_send(self, connection):
"""处理发送消息的任务"""
try:
while connection.is_active:
# 从队列获取消息
try:
message = await asyncio.wait_for(connection.send_queue.get(), timeout=1.0)
# 发送消息
await quart.websocket.send(json.dumps(message))
except asyncio.TimeoutError:
# 超时继续循环
continue
except Exception as e:
logger.error(f'Send message error: {e}', exc_info=True)
finally:
connection.is_active = False

View File

@@ -156,7 +156,7 @@ class PlatformManager:
bots: list[RuntimeBot]
webchat_proxy_bot: RuntimeBot
websocket_proxy_bot: RuntimeBot
adapter_components: list[engine.Component]
@@ -178,31 +178,29 @@ class PlatformManager:
adapter_dict[component.metadata.name] = component.get_python_component_class()
self.adapter_dict = adapter_dict
webchat_adapter_class = self.adapter_dict['webchat']
# initialize webchat adapter
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
webchat_adapter_inst = webchat_adapter_class(
# initialize websocket adapter
websocket_adapter_class = self.adapter_dict['websocket']
websocket_logger = EventLogger(name='websocket-adapter', ap=self.ap)
websocket_adapter_inst = websocket_adapter_class(
{},
webchat_logger,
websocket_logger,
ap=self.ap,
is_stream=False,
)
self.webchat_proxy_bot = RuntimeBot(
self.websocket_proxy_bot = RuntimeBot(
ap=self.ap,
bot_entity=persistence_bot.Bot(
uuid='webchat-proxy-bot',
name='WebChat',
uuid='websocket-proxy-bot',
name='WebSocket',
description='',
adapter='webchat',
adapter='websocket',
adapter_config={},
enable=True,
),
adapter=webchat_adapter_inst,
logger=webchat_logger,
adapter=websocket_adapter_inst,
logger=websocket_logger,
)
await self.webchat_proxy_bot.initialize()
await self.websocket_proxy_bot.initialize()
await self.load_bots_from_db()
@@ -271,7 +269,7 @@ class PlatformManager:
def get_available_adapters_info(self) -> list[dict]:
return [
component.to_plain_dict() for component in self.adapter_components if component.metadata.name != 'webchat'
component.to_plain_dict() for component in self.adapter_components if component.metadata.name != 'websocket'
]
def get_available_adapter_info_by_name(self, name: str) -> dict | None:
@@ -288,7 +286,7 @@ class PlatformManager:
async def run(self):
# This method will only be called when the application launching
await self.webchat_proxy_bot.run()
await self.websocket_proxy_bot.run()
for bot in self.bots:
if bot.enable:

View File

@@ -1,304 +0,0 @@
import asyncio
import logging
import typing
from datetime import datetime
import pydantic
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ...core import app
logger = logging.getLogger(__name__)
class WebChatMessage(pydantic.BaseModel):
id: int
role: str
content: str
message_chain: list[dict]
timestamp: str
is_final: bool = False
class WebChatSession:
id: str
message_lists: dict[str, list[WebChatMessage]] = {}
resp_waiters: dict[int, asyncio.Future[WebChatMessage]]
resp_queues: dict[int, asyncio.Queue[WebChatMessage]]
def __init__(self, id: str):
self.id = id
self.message_lists = {}
self.resp_waiters = {}
self.resp_queues = {}
def get_message_list(self, pipeline_uuid: str) -> list[WebChatMessage]:
if pipeline_uuid not in self.message_lists:
self.message_lists[pipeline_uuid] = []
return self.message_lists[pipeline_uuid]
class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""WebChat调试适配器用于流水线调试"""
webchat_person_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
webchat_group_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
listeners: dict[
typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = pydantic.Field(default_factory=dict, exclude=True)
is_stream: bool = pydantic.Field(exclude=True)
debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True)
ap: app.Application = pydantic.Field(exclude=True)
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
super().__init__(
config=config,
logger=logger,
**kwargs,
)
self.webchat_person_session = WebChatSession(id='webchatperson')
self.webchat_group_session = WebChatSession(id='webchatgroup')
self.bot_account_id = 'webchatbot'
self.debug_messages = {}
async def send_message(
self,
target_type: str,
target_id: str,
message: platform_message.MessageChain,
) -> dict:
"""发送消息到调试会话"""
session_key = target_id
if session_key not in self.debug_messages:
self.debug_messages[session_key] = []
message_data = {
'id': len(self.debug_messages[session_key]) + 1,
'type': 'bot',
'content': str(message),
'timestamp': datetime.now().isoformat(),
'message_chain': [component.__dict__ for component in message],
}
self.debug_messages[session_key].append(message_data)
await self.logger.info(f'Send message to {session_key}: {message}')
return message_data
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
) -> dict:
"""回复消息"""
message_data = WebChatMessage(
id=-1,
role='assistant',
content=str(message),
message_chain=[component.__dict__ for component in message],
timestamp=datetime.now().isoformat(),
)
# notify waiter
if isinstance(message_source, platform_events.FriendMessage):
await self.webchat_person_session.resp_queues[message_source.message_chain.message_id].put(message_data)
elif isinstance(message_source, platform_events.GroupMessage):
await self.webchat_group_session.resp_queues[message_source.message_chain.message_id].put(message_data)
return message_data.model_dump()
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message,
message: platform_message.MessageChain,
quote_origin: bool = False,
is_final: bool = False,
) -> dict:
"""回复消息"""
message_data = WebChatMessage(
id=-1,
role='assistant',
content=str(message),
message_chain=[component.__dict__ for component in message],
timestamp=datetime.now().isoformat(),
)
# notify waiter
session = (
self.webchat_group_session
if isinstance(message_source, platform_events.GroupMessage)
else self.webchat_person_session
)
if message_source.message_chain.message_id not in session.resp_waiters:
# session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue()
queue = session.resp_queues[message_source.message_chain.message_id]
# if isinstance(message_source, platform_events.FriendMessage):
# queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id]
# elif isinstance(message_source, platform_events.GroupMessage):
# queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id]
if is_final and bot_message.tool_calls is None:
message_data.is_final = True
# print(message_data)
await queue.put(message_data)
return message_data.model_dump()
async def is_stream_output_supported(self) -> bool:
return self.is_stream
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
func: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
],
):
"""注册事件监听器"""
self.listeners[event_type] = func
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
func: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
],
):
"""取消注册事件监听器"""
del self.listeners[event_type]
async def is_muted(self, group_id: int) -> bool:
return False
async def run_async(self):
"""运行适配器"""
await self.logger.info('WebChat调试适配器已启动')
try:
while True:
await asyncio.sleep(1)
except asyncio.CancelledError:
await self.logger.info('WebChat调试适配器已停止')
raise
async def kill(self):
"""停止适配器"""
await self.logger.info('WebChat调试适配器正在停止')
async def send_webchat_message(
self,
pipeline_uuid: str,
session_type: str,
message_chain_obj: typing.List[dict],
is_stream: bool = False,
) -> dict:
self.is_stream = is_stream
"""发送调试消息到流水线"""
if session_type == 'person':
use_session = self.webchat_person_session
else:
use_session = self.webchat_group_session
message_chain = platform_message.MessageChain.parse_obj(message_chain_obj)
message_id = len(use_session.get_message_list(pipeline_uuid)) + 1
use_session.resp_queues[message_id] = asyncio.Queue()
logger.debug(f'Initialized queue for message_id: {message_id}')
use_session.get_message_list(pipeline_uuid).append(
WebChatMessage(
id=message_id,
role='user',
content=str(message_chain),
message_chain=message_chain_obj,
timestamp=datetime.now().isoformat(),
)
)
message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))
if session_type == 'person':
sender = platform_entities.Friend(id='webchatperson', nickname='User', remark='User')
event = platform_events.FriendMessage(
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
)
else:
group = platform_entities.Group(
id='webchatgroup', name='Group', permission=platform_entities.Permission.Member
)
sender = platform_entities.GroupMember(
id='webchatperson',
member_name='User',
group=group,
permission=platform_entities.Permission.Member,
)
event = platform_events.GroupMessage(
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
)
self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
# trigger pipeline
if event.__class__ in self.listeners:
await self.listeners[event.__class__](event, self)
if is_stream:
queue = use_session.resp_queues[message_id]
msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1
while True:
resp_message = await queue.get()
resp_message.id = msg_id
if resp_message.is_final:
resp_message.id = msg_id
use_session.get_message_list(pipeline_uuid).append(resp_message)
yield resp_message.model_dump()
break
yield resp_message.model_dump()
use_session.resp_queues.pop(message_id)
else: # non-stream
# set waiter
# waiter = asyncio.Future[WebChatMessage]()
# use_session.resp_waiters[message_id] = waiter
# # waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id))
#
# resp_message = await waiter
#
# resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1
#
# use_session.get_message_list(pipeline_uuid).append(resp_message)
#
# yield resp_message.model_dump()
msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1
queue = use_session.resp_queues[message_id]
resp_message = await queue.get()
use_session.get_message_list(pipeline_uuid).append(resp_message)
resp_message.id = msg_id
resp_message.is_final = True
yield resp_message.model_dump()
def get_webchat_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
"""获取调试消息历史"""
if session_type == 'person':
return [message.model_dump() for message in self.webchat_person_session.get_message_list(pipeline_uuid)]
else:
return [message.model_dump() for message in self.webchat_group_session.get_message_list(pipeline_uuid)]

View File

@@ -1,17 +0,0 @@
apiVersion: v1
kind: MessagePlatformAdapter
metadata:
name: webchat
label:
en_US: "WebChat Debug"
zh_Hans: "网页聊天调试"
description:
en_US: "WebChat adapter for pipeline debugging"
zh_Hans: "用于流水线调试的网页聊天适配器"
icon: ""
spec:
config: []
execution:
python:
path: "webchat.py"
attr: "WebChatAdapter"

View File

@@ -0,0 +1,17 @@
apiVersion: v1
kind: MessagePlatformAdapter
metadata:
name: websocket
label:
en_US: "WebSocket Chat"
zh_Hans: "WebSocket 聊天"
description:
en_US: "WebSocket adapter for bidirectional real-time communication"
zh_Hans: "用于双向实时通信的 WebSocket 适配器"
icon: ""
spec:
config: []
execution:
python:
path: "websocket_adapter.py"
attr: "WebSocketAdapter"

View File

@@ -0,0 +1,402 @@
"""WebSocket适配器 - 支持双向通信的IM系统"""
import asyncio
import logging
import typing
from datetime import datetime
import pydantic
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ...core import app
from .websocket_manager import ws_connection_manager, WebSocketConnection
logger = logging.getLogger(__name__)
class WebSocketMessage(pydantic.BaseModel):
"""WebSocket消息格式"""
id: int
role: str # 'user' or 'assistant'
content: str
message_chain: list[dict]
timestamp: str
is_final: bool = False
connection_id: str = ''
"""发送者连接ID"""
class WebSocketSession:
"""WebSocket会话 - 管理单个会话的消息历史"""
id: str
message_lists: dict[str, list[WebSocketMessage]] = {}
"""消息列表 {pipeline_uuid: [messages]}"""
def __init__(self, id: str):
self.id = id
self.message_lists = {}
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
if pipeline_uuid not in self.message_lists:
self.message_lists[pipeline_uuid] = []
return self.message_lists[pipeline_uuid]
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""WebSocket适配器 - 支持双向实时通信"""
websocket_person_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)
websocket_group_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)
listeners: dict[
typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = pydantic.Field(default_factory=dict, exclude=True)
ap: app.Application = pydantic.Field(exclude=True)
# 主动推送消息的队列
outbound_message_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)
"""后端主动推送消息的队列"""
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
super().__init__(
config=config,
logger=logger,
**kwargs,
)
self.websocket_person_session = WebSocketSession(id='websocketperson')
self.websocket_group_session = WebSocketSession(id='websocketgroup')
self.bot_account_id = 'websocketbot'
self.outbound_message_queue = asyncio.Queue()
async def send_message(
self,
target_type: str,
target_id: str,
message: platform_message.MessageChain,
) -> dict:
"""发送消息 - 这里用于主动推送消息到前端"""
message_data = {
'type': 'bot_message',
'target_type': target_type,
'target_id': target_id,
'content': str(message),
'message_chain': [component.__dict__ for component in message],
'timestamp': datetime.now().isoformat(),
}
# 推送到所有相关连接
await self.outbound_message_queue.put(message_data)
await self.logger.info(f'Send message to {target_id}: {message}')
return message_data
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
) -> dict:
"""回复消息 - 非流式"""
# 获取会话和pipeline信息
session = (
self.websocket_group_session
if isinstance(message_source, platform_events.GroupMessage)
else self.websocket_person_session
)
# 从message_source获取pipeline_uuid和connection_id
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
# session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
# 生成新的消息ID
msg_id = len(session.get_message_list(pipeline_uuid)) + 1
message_data = WebSocketMessage(
id=msg_id,
role='assistant',
content=str(message),
message_chain=[component.__dict__ for component in message],
timestamp=datetime.now().isoformat(),
is_final=True,
)
# 保存到历史记录
session.get_message_list(pipeline_uuid).append(message_data)
# 直接广播到所有该pipeline的连接
await ws_connection_manager.broadcast_to_pipeline(
pipeline_uuid,
{
'type': 'response',
'data': message_data.model_dump(),
},
)
return message_data.model_dump()
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message,
message: platform_message.MessageChain,
quote_origin: bool = False,
is_final: bool = False,
) -> dict:
"""回复消息块 - 流式"""
# 获取会话和pipeline信息
session = (
self.websocket_group_session
if isinstance(message_source, platform_events.GroupMessage)
else self.websocket_person_session
)
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
message_list = session.get_message_list(pipeline_uuid)
# 检查是否是新的流式消息通过bot_message对象判断
# 如果列表为空或者最后一条消息已经is_final=True则创建新消息
if not message_list or message_list[-1].is_final:
# 创建新消息
msg_id = len(message_list) + 1
message_data = WebSocketMessage(
id=msg_id,
role='assistant',
content=str(message),
message_chain=[component.__dict__ for component in message],
timestamp=datetime.now().isoformat(),
is_final=is_final and bot_message.tool_calls is None,
)
# 只有在is_final时才保存到历史记录
if is_final and bot_message.tool_calls is None:
message_list.append(message_data)
else:
# 更新最后一条消息
msg_id = message_list[-1].id
message_data = WebSocketMessage(
id=msg_id,
role='assistant',
content=str(message),
message_chain=[component.__dict__ for component in message],
timestamp=message_list[-1].timestamp, # 保持原始时间戳
is_final=is_final and bot_message.tool_calls is None,
)
# 如果是final更新历史记录中的最后一条
if is_final and bot_message.tool_calls is None:
message_list[-1] = message_data
# 直接广播到所有该pipeline的连接
await ws_connection_manager.broadcast_to_pipeline(
pipeline_uuid,
{
'type': 'response',
'data': message_data.model_dump(),
},
)
return message_data.model_dump()
async def is_stream_output_supported(self) -> bool:
"""WebSocket始终支持流式输出"""
return True
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
func: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
],
):
"""注册事件监听器"""
self.listeners[event_type] = func
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
func: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
],
):
"""取消注册事件监听器"""
del self.listeners[event_type]
async def is_muted(self, group_id: int) -> bool:
return False
async def run_async(self):
"""运行适配器"""
await self.logger.info('WebSocket适配器已启动')
try:
while True:
# 处理主动推送消息
if not self.outbound_message_queue.empty():
try:
message = await asyncio.wait_for(self.outbound_message_queue.get(), timeout=0.1)
# 广播到所有相关连接
target_id = message.get('target_id', '')
await ws_connection_manager.broadcast_to_pipeline(target_id, message)
except asyncio.TimeoutError:
pass
await asyncio.sleep(0.1)
except asyncio.CancelledError:
await self.logger.info('WebSocket适配器已停止')
raise
async def kill(self):
"""停止适配器"""
await self.logger.info('WebSocket适配器正在停止')
async def _process_image_components(self, message_chain_obj: list):
"""
处理消息链中的图片组件将path转换为base64
Args:
message_chain_obj: 消息链对象列表
"""
import base64
storage_mgr = self.ap.storage_mgr
for component in message_chain_obj:
if component.get('type') == 'Image' and component.get('path'):
try:
# 从storage读取文件
file_content = await storage_mgr.storage_provider.load(component['path'])
# 转换为base64
base64_str = base64.b64encode(file_content).decode('utf-8')
# 添加data URI前缀根据文件扩展名判断MIME类型
file_key = component['path']
if file_key.lower().endswith(('.jpg', '.jpeg')):
mime_type = 'image/jpeg'
elif file_key.lower().endswith('.png'):
mime_type = 'image/png'
elif file_key.lower().endswith('.gif'):
mime_type = 'image/gif'
elif file_key.lower().endswith('.webp'):
mime_type = 'image/webp'
else:
mime_type = 'image/png' # 默认
component['base64'] = f'data:{mime_type};base64,{base64_str}'
await storage_mgr.storage_provider.delete(component['path'])
component['path'] = ''
# 保留path字段用于后端处理前端使用base64显示
except Exception as e:
await self.logger.error(f'加载图片文件失败 {component["path"]}: {e}')
async def handle_websocket_message(
self,
connection: WebSocketConnection,
message_data: dict,
):
"""
处理从WebSocket接收的消息
这个方法只负责接收消息、保存到历史记录、并触发事件处理
不等待任何响应响应消息会通过reply_message/reply_message_chunk直接发送
Args:
connection: WebSocket连接对象
message_data: 消息数据
"""
pipeline_uuid = connection.pipeline_uuid
session_type = connection.session_type
# 选择会话
use_session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session
# 解析消息链
message_chain_obj = message_data.get('message', [])
# 处理图片组件将path转换为base64
await self._process_image_components(message_chain_obj)
message_chain = platform_message.MessageChain.model_validate(message_chain_obj)
# 生成消息ID
message_id = len(use_session.get_message_list(pipeline_uuid)) + 1
# 保存用户消息
user_message = WebSocketMessage(
id=message_id,
role='user',
content=str(message_chain),
message_chain=message_chain_obj,
timestamp=datetime.now().isoformat(),
connection_id=connection.connection_id,
is_final=True, # 用户消息始终是完整的,非流式
)
use_session.get_message_list(pipeline_uuid).append(user_message)
# 广播用户消息到所有连接(包括发送者)
await ws_connection_manager.broadcast_to_pipeline(
pipeline_uuid,
{
'type': 'user_message',
'data': user_message.model_dump(),
},
)
# 添加消息源
message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))
# 创建事件
if session_type == 'person':
sender = platform_entities.Friend(
id=f'websocket_{connection.connection_id}', nickname='User', remark='User'
)
event = platform_events.FriendMessage(
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
)
else:
group = platform_entities.Group(
id='websocketgroup', name='Group', permission=platform_entities.Permission.Member
)
sender = platform_entities.GroupMember(
id=f'websocket_{connection.connection_id}',
member_name='User',
group=group,
permission=platform_entities.Permission.Member,
)
event = platform_events.GroupMessage(
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
)
# 设置流水线UUID
self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
# 异步触发事件处理(不等待结果)
if event.__class__ in self.listeners:
asyncio.create_task(self.listeners[event.__class__](event, self))
def get_websocket_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
"""获取消息历史"""
if session_type == 'person':
return [message.model_dump() for message in self.websocket_person_session.get_message_list(pipeline_uuid)]
else:
return [message.model_dump() for message in self.websocket_group_session.get_message_list(pipeline_uuid)]
def reset_session(self, pipeline_uuid: str, session_type: str):
"""重置会话"""
if session_type == 'person':
if pipeline_uuid in self.websocket_person_session.message_lists:
self.websocket_person_session.message_lists[pipeline_uuid] = []
else:
if pipeline_uuid in self.websocket_group_session.message_lists:
self.websocket_group_session.message_lists[pipeline_uuid] = []

View File

@@ -0,0 +1,177 @@
"""WebSocket连接管理器 - 管理多个并发WebSocket连接"""
import asyncio
import logging
import typing
import uuid
from datetime import datetime
import pydantic
logger = logging.getLogger(__name__)
class WebSocketConnection(pydantic.BaseModel):
"""单个WebSocket连接"""
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
connection_id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
"""连接唯一ID"""
pipeline_uuid: str
"""关联的流水线UUID"""
session_type: str # 'person' or 'group'
"""会话类型"""
websocket: typing.Any = pydantic.Field(exclude=True)
"""WebSocket连接对象 (quart.websocket)"""
created_at: datetime = pydantic.Field(default_factory=datetime.now)
"""连接创建时间"""
last_active: datetime = pydantic.Field(default_factory=datetime.now)
"""最后活跃时间"""
send_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)
"""发送消息队列"""
is_active: bool = True
"""连接是否活跃"""
metadata: dict = pydantic.Field(default_factory=dict)
"""连接元数据(可存储额外信息)"""
class WebSocketConnectionManager:
"""WebSocket连接管理器 - 支持多连接并发"""
def __init__(self):
self.connections: dict[str, WebSocketConnection] = {}
"""所有活跃连接 {connection_id: connection}"""
self.pipeline_connections: dict[str, set[str]] = {}
"""流水线到连接的映射 {pipeline_uuid: {connection_id, ...}}"""
self.session_connections: dict[str, set[str]] = {}
"""会话类型到连接的映射 {session_type: {connection_id, ...}}"""
self._lock = asyncio.Lock()
"""线程锁,保护并发访问"""
async def add_connection(
self,
websocket: typing.Any,
pipeline_uuid: str,
session_type: str,
metadata: dict = None,
) -> WebSocketConnection:
"""添加新的WebSocket连接"""
async with self._lock:
connection = WebSocketConnection(
pipeline_uuid=pipeline_uuid,
session_type=session_type,
websocket=websocket,
metadata=metadata or {},
)
self.connections[connection.connection_id] = connection
# 更新流水线映射
if pipeline_uuid not in self.pipeline_connections:
self.pipeline_connections[pipeline_uuid] = set()
self.pipeline_connections[pipeline_uuid].add(connection.connection_id)
# 更新会话类型映射
if session_type not in self.session_connections:
self.session_connections[session_type] = set()
self.session_connections[session_type].add(connection.connection_id)
logger.debug(
f'WebSocket connection established: {connection.connection_id} '
f'(pipeline={pipeline_uuid}, session_type={session_type})'
)
return connection
async def remove_connection(self, connection_id: str):
"""移除WebSocket连接"""
async with self._lock:
if connection_id not in self.connections:
return
connection = self.connections[connection_id]
connection.is_active = False
# 从流水线映射中移除
if connection.pipeline_uuid in self.pipeline_connections:
self.pipeline_connections[connection.pipeline_uuid].discard(connection_id)
if not self.pipeline_connections[connection.pipeline_uuid]:
del self.pipeline_connections[connection.pipeline_uuid]
# 从会话类型映射中移除
if connection.session_type in self.session_connections:
self.session_connections[connection.session_type].discard(connection_id)
if not self.session_connections[connection.session_type]:
del self.session_connections[connection.session_type]
del self.connections[connection_id]
logger.debug(f'WebSocket connection disconnected: {connection_id}')
async def get_connection(self, connection_id: str) -> typing.Optional[WebSocketConnection]:
"""获取指定连接"""
return self.connections.get(connection_id)
async def get_connections_by_pipeline(self, pipeline_uuid: str) -> list[WebSocketConnection]:
"""获取指定流水线的所有连接"""
connection_ids = self.pipeline_connections.get(pipeline_uuid, set())
return [self.connections[cid] for cid in connection_ids if cid in self.connections]
async def get_connections_by_session_type(self, session_type: str) -> list[WebSocketConnection]:
"""获取指定会话类型的所有连接"""
connection_ids = self.session_connections.get(session_type, set())
return [self.connections[cid] for cid in connection_ids if cid in self.connections]
async def broadcast_to_pipeline(self, pipeline_uuid: str, message: dict):
"""向指定流水线的所有连接广播消息"""
connections = await self.get_connections_by_pipeline(pipeline_uuid)
tasks = []
for conn in connections:
tasks.append(self.send_to_connection(conn.connection_id, message))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def send_to_connection(self, connection_id: str, message: dict):
"""向指定连接发送消息"""
connection = await self.get_connection(connection_id)
if not connection or not connection.is_active:
logger.warning(f'Attempt to send message to invalid connection: {connection_id}')
return
try:
await connection.send_queue.put(message)
connection.last_active = datetime.now()
except Exception as e:
logger.error(f'Failed to send message to connection {connection_id}: {e}')
await self.remove_connection(connection_id)
async def update_activity(self, connection_id: str):
"""更新连接活跃时间"""
connection = await self.get_connection(connection_id)
if connection:
connection.last_active = datetime.now()
def get_stats(self) -> dict:
"""获取连接统计信息"""
return {
'total_connections': len(self.connections),
'pipelines': len(self.pipeline_connections),
'connections_by_pipeline': {k: len(v) for k, v in self.pipeline_connections.items()},
'connections_by_session_type': {k: len(v) for k, v in self.session_connections.items()},
}
# 全局连接管理器实例
ws_connection_manager = WebSocketConnectionManager()