mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 07:54:19 +00:00
refactor: switch webchat from sse to websocket (#1808)
* refactor: switch webchat from sse to websocket * perf: image preview dialog * chore: remove console.log
This commit is contained in:
committed by
GitHub
parent
348620ac0a
commit
d09b823c49
@@ -156,7 +156,7 @@ class PlatformManager:
|
||||
|
||||
bots: list[RuntimeBot]
|
||||
|
||||
webchat_proxy_bot: RuntimeBot
|
||||
websocket_proxy_bot: RuntimeBot
|
||||
|
||||
adapter_components: list[engine.Component]
|
||||
|
||||
@@ -178,31 +178,29 @@ class PlatformManager:
|
||||
adapter_dict[component.metadata.name] = component.get_python_component_class()
|
||||
self.adapter_dict = adapter_dict
|
||||
|
||||
webchat_adapter_class = self.adapter_dict['webchat']
|
||||
|
||||
# initialize webchat adapter
|
||||
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
|
||||
webchat_adapter_inst = webchat_adapter_class(
|
||||
# initialize websocket adapter
|
||||
websocket_adapter_class = self.adapter_dict['websocket']
|
||||
websocket_logger = EventLogger(name='websocket-adapter', ap=self.ap)
|
||||
websocket_adapter_inst = websocket_adapter_class(
|
||||
{},
|
||||
webchat_logger,
|
||||
websocket_logger,
|
||||
ap=self.ap,
|
||||
is_stream=False,
|
||||
)
|
||||
|
||||
self.webchat_proxy_bot = RuntimeBot(
|
||||
self.websocket_proxy_bot = RuntimeBot(
|
||||
ap=self.ap,
|
||||
bot_entity=persistence_bot.Bot(
|
||||
uuid='webchat-proxy-bot',
|
||||
name='WebChat',
|
||||
uuid='websocket-proxy-bot',
|
||||
name='WebSocket',
|
||||
description='',
|
||||
adapter='webchat',
|
||||
adapter='websocket',
|
||||
adapter_config={},
|
||||
enable=True,
|
||||
),
|
||||
adapter=webchat_adapter_inst,
|
||||
logger=webchat_logger,
|
||||
adapter=websocket_adapter_inst,
|
||||
logger=websocket_logger,
|
||||
)
|
||||
await self.webchat_proxy_bot.initialize()
|
||||
await self.websocket_proxy_bot.initialize()
|
||||
|
||||
await self.load_bots_from_db()
|
||||
|
||||
@@ -271,7 +269,7 @@ class PlatformManager:
|
||||
|
||||
def get_available_adapters_info(self) -> list[dict]:
|
||||
return [
|
||||
component.to_plain_dict() for component in self.adapter_components if component.metadata.name != 'webchat'
|
||||
component.to_plain_dict() for component in self.adapter_components if component.metadata.name != 'websocket'
|
||||
]
|
||||
|
||||
def get_available_adapter_info_by_name(self, name: str) -> dict | None:
|
||||
@@ -288,7 +286,7 @@ class PlatformManager:
|
||||
|
||||
async def run(self):
|
||||
# This method will only be called when the application launching
|
||||
await self.webchat_proxy_bot.run()
|
||||
await self.websocket_proxy_bot.run()
|
||||
|
||||
for bot in self.bots:
|
||||
if bot.enable:
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime
|
||||
|
||||
import pydantic
|
||||
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
from ...core import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebChatMessage(pydantic.BaseModel):
|
||||
id: int
|
||||
role: str
|
||||
content: str
|
||||
message_chain: list[dict]
|
||||
timestamp: str
|
||||
is_final: bool = False
|
||||
|
||||
|
||||
class WebChatSession:
|
||||
id: str
|
||||
message_lists: dict[str, list[WebChatMessage]] = {}
|
||||
resp_waiters: dict[int, asyncio.Future[WebChatMessage]]
|
||||
resp_queues: dict[int, asyncio.Queue[WebChatMessage]]
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
self.message_lists = {}
|
||||
self.resp_waiters = {}
|
||||
self.resp_queues = {}
|
||||
|
||||
def get_message_list(self, pipeline_uuid: str) -> list[WebChatMessage]:
|
||||
if pipeline_uuid not in self.message_lists:
|
||||
self.message_lists[pipeline_uuid] = []
|
||||
|
||||
return self.message_lists[pipeline_uuid]
|
||||
|
||||
|
||||
class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""WebChat调试适配器,用于流水线调试"""
|
||||
|
||||
webchat_person_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
|
||||
webchat_group_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
|
||||
|
||||
listeners: dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
] = pydantic.Field(default_factory=dict, exclude=True)
|
||||
|
||||
is_stream: bool = pydantic.Field(exclude=True)
|
||||
debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True)
|
||||
|
||||
ap: app.Application = pydantic.Field(exclude=True)
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.webchat_person_session = WebChatSession(id='webchatperson')
|
||||
self.webchat_group_session = WebChatSession(id='webchatgroup')
|
||||
|
||||
self.bot_account_id = 'webchatbot'
|
||||
|
||||
self.debug_messages = {}
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain,
|
||||
) -> dict:
|
||||
"""发送消息到调试会话"""
|
||||
session_key = target_id
|
||||
|
||||
if session_key not in self.debug_messages:
|
||||
self.debug_messages[session_key] = []
|
||||
|
||||
message_data = {
|
||||
'id': len(self.debug_messages[session_key]) + 1,
|
||||
'type': 'bot',
|
||||
'content': str(message),
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'message_chain': [component.__dict__ for component in message],
|
||||
}
|
||||
|
||||
self.debug_messages[session_key].append(message_data)
|
||||
|
||||
await self.logger.info(f'Send message to {session_key}: {message}')
|
||||
|
||||
return message_data
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
) -> dict:
|
||||
"""回复消息"""
|
||||
message_data = WebChatMessage(
|
||||
id=-1,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
# notify waiter
|
||||
if isinstance(message_source, platform_events.FriendMessage):
|
||||
await self.webchat_person_session.resp_queues[message_source.message_chain.message_id].put(message_data)
|
||||
elif isinstance(message_source, platform_events.GroupMessage):
|
||||
await self.webchat_group_session.resp_queues[message_source.message_chain.message_id].put(message_data)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
) -> dict:
|
||||
"""回复消息"""
|
||||
message_data = WebChatMessage(
|
||||
id=-1,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
# notify waiter
|
||||
session = (
|
||||
self.webchat_group_session
|
||||
if isinstance(message_source, platform_events.GroupMessage)
|
||||
else self.webchat_person_session
|
||||
)
|
||||
if message_source.message_chain.message_id not in session.resp_waiters:
|
||||
# session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue()
|
||||
queue = session.resp_queues[message_source.message_chain.message_id]
|
||||
|
||||
# if isinstance(message_source, platform_events.FriendMessage):
|
||||
# queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id]
|
||||
# elif isinstance(message_source, platform_events.GroupMessage):
|
||||
# queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id]
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_data.is_final = True
|
||||
# print(message_data)
|
||||
await queue.put(message_data)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
return self.is_stream
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
func: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
|
||||
],
|
||||
):
|
||||
"""注册事件监听器"""
|
||||
self.listeners[event_type] = func
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
func: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
|
||||
],
|
||||
):
|
||||
"""取消注册事件监听器"""
|
||||
del self.listeners[event_type]
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
async def run_async(self):
|
||||
"""运行适配器"""
|
||||
await self.logger.info('WebChat调试适配器已启动')
|
||||
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
except asyncio.CancelledError:
|
||||
await self.logger.info('WebChat调试适配器已停止')
|
||||
raise
|
||||
|
||||
async def kill(self):
|
||||
"""停止适配器"""
|
||||
await self.logger.info('WebChat调试适配器正在停止')
|
||||
|
||||
async def send_webchat_message(
|
||||
self,
|
||||
pipeline_uuid: str,
|
||||
session_type: str,
|
||||
message_chain_obj: typing.List[dict],
|
||||
is_stream: bool = False,
|
||||
) -> dict:
|
||||
self.is_stream = is_stream
|
||||
"""发送调试消息到流水线"""
|
||||
if session_type == 'person':
|
||||
use_session = self.webchat_person_session
|
||||
else:
|
||||
use_session = self.webchat_group_session
|
||||
|
||||
message_chain = platform_message.MessageChain.parse_obj(message_chain_obj)
|
||||
|
||||
message_id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
|
||||
use_session.resp_queues[message_id] = asyncio.Queue()
|
||||
logger.debug(f'Initialized queue for message_id: {message_id}')
|
||||
|
||||
use_session.get_message_list(pipeline_uuid).append(
|
||||
WebChatMessage(
|
||||
id=message_id,
|
||||
role='user',
|
||||
content=str(message_chain),
|
||||
message_chain=message_chain_obj,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
)
|
||||
|
||||
message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))
|
||||
|
||||
if session_type == 'person':
|
||||
sender = platform_entities.Friend(id='webchatperson', nickname='User', remark='User')
|
||||
event = platform_events.FriendMessage(
|
||||
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
|
||||
)
|
||||
else:
|
||||
group = platform_entities.Group(
|
||||
id='webchatgroup', name='Group', permission=platform_entities.Permission.Member
|
||||
)
|
||||
sender = platform_entities.GroupMember(
|
||||
id='webchatperson',
|
||||
member_name='User',
|
||||
group=group,
|
||||
permission=platform_entities.Permission.Member,
|
||||
)
|
||||
event = platform_events.GroupMessage(
|
||||
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
|
||||
)
|
||||
|
||||
self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
|
||||
|
||||
# trigger pipeline
|
||||
if event.__class__ in self.listeners:
|
||||
await self.listeners[event.__class__](event, self)
|
||||
|
||||
if is_stream:
|
||||
queue = use_session.resp_queues[message_id]
|
||||
msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
while True:
|
||||
resp_message = await queue.get()
|
||||
resp_message.id = msg_id
|
||||
if resp_message.is_final:
|
||||
resp_message.id = msg_id
|
||||
use_session.get_message_list(pipeline_uuid).append(resp_message)
|
||||
yield resp_message.model_dump()
|
||||
break
|
||||
yield resp_message.model_dump()
|
||||
use_session.resp_queues.pop(message_id)
|
||||
|
||||
else: # non-stream
|
||||
# set waiter
|
||||
# waiter = asyncio.Future[WebChatMessage]()
|
||||
# use_session.resp_waiters[message_id] = waiter
|
||||
# # waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id))
|
||||
#
|
||||
# resp_message = await waiter
|
||||
#
|
||||
# resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
#
|
||||
# use_session.get_message_list(pipeline_uuid).append(resp_message)
|
||||
#
|
||||
# yield resp_message.model_dump()
|
||||
msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
|
||||
queue = use_session.resp_queues[message_id]
|
||||
resp_message = await queue.get()
|
||||
use_session.get_message_list(pipeline_uuid).append(resp_message)
|
||||
resp_message.id = msg_id
|
||||
resp_message.is_final = True
|
||||
|
||||
yield resp_message.model_dump()
|
||||
|
||||
def get_webchat_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
|
||||
"""获取调试消息历史"""
|
||||
if session_type == 'person':
|
||||
return [message.model_dump() for message in self.webchat_person_session.get_message_list(pipeline_uuid)]
|
||||
else:
|
||||
return [message.model_dump() for message in self.webchat_group_session.get_message_list(pipeline_uuid)]
|
||||
@@ -1,17 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: MessagePlatformAdapter
|
||||
metadata:
|
||||
name: webchat
|
||||
label:
|
||||
en_US: "WebChat Debug"
|
||||
zh_Hans: "网页聊天调试"
|
||||
description:
|
||||
en_US: "WebChat adapter for pipeline debugging"
|
||||
zh_Hans: "用于流水线调试的网页聊天适配器"
|
||||
icon: ""
|
||||
spec:
|
||||
config: []
|
||||
execution:
|
||||
python:
|
||||
path: "webchat.py"
|
||||
attr: "WebChatAdapter"
|
||||
@@ -0,0 +1,17 @@
|
||||
apiVersion: v1
|
||||
kind: MessagePlatformAdapter
|
||||
metadata:
|
||||
name: websocket
|
||||
label:
|
||||
en_US: "WebSocket Chat"
|
||||
zh_Hans: "WebSocket 聊天"
|
||||
description:
|
||||
en_US: "WebSocket adapter for bidirectional real-time communication"
|
||||
zh_Hans: "用于双向实时通信的 WebSocket 适配器"
|
||||
icon: ""
|
||||
spec:
|
||||
config: []
|
||||
execution:
|
||||
python:
|
||||
path: "websocket_adapter.py"
|
||||
attr: "WebSocketAdapter"
|
||||
@@ -0,0 +1,402 @@
|
||||
"""WebSocket适配器 - 支持双向通信的IM系统"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime
|
||||
|
||||
import pydantic
|
||||
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
from ...core import app
|
||||
from .websocket_manager import ws_connection_manager, WebSocketConnection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketMessage(pydantic.BaseModel):
|
||||
"""WebSocket消息格式"""
|
||||
|
||||
id: int
|
||||
role: str # 'user' or 'assistant'
|
||||
content: str
|
||||
message_chain: list[dict]
|
||||
timestamp: str
|
||||
is_final: bool = False
|
||||
connection_id: str = ''
|
||||
"""发送者连接ID"""
|
||||
|
||||
|
||||
class WebSocketSession:
|
||||
"""WebSocket会话 - 管理单个会话的消息历史"""
|
||||
|
||||
id: str
|
||||
message_lists: dict[str, list[WebSocketMessage]] = {}
|
||||
"""消息列表 {pipeline_uuid: [messages]}"""
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
self.message_lists = {}
|
||||
|
||||
def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]:
|
||||
if pipeline_uuid not in self.message_lists:
|
||||
self.message_lists[pipeline_uuid] = []
|
||||
return self.message_lists[pipeline_uuid]
|
||||
|
||||
|
||||
class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""WebSocket适配器 - 支持双向实时通信"""
|
||||
|
||||
websocket_person_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)
|
||||
websocket_group_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession)
|
||||
|
||||
listeners: dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||
] = pydantic.Field(default_factory=dict, exclude=True)
|
||||
|
||||
ap: app.Application = pydantic.Field(exclude=True)
|
||||
|
||||
# 主动推送消息的队列
|
||||
outbound_message_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)
|
||||
"""后端主动推送消息的队列"""
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.websocket_person_session = WebSocketSession(id='websocketperson')
|
||||
self.websocket_group_session = WebSocketSession(id='websocketgroup')
|
||||
|
||||
self.bot_account_id = 'websocketbot'
|
||||
self.outbound_message_queue = asyncio.Queue()
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: platform_message.MessageChain,
|
||||
) -> dict:
|
||||
"""发送消息 - 这里用于主动推送消息到前端"""
|
||||
message_data = {
|
||||
'type': 'bot_message',
|
||||
'target_type': target_type,
|
||||
'target_id': target_id,
|
||||
'content': str(message),
|
||||
'message_chain': [component.__dict__ for component in message],
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# 推送到所有相关连接
|
||||
await self.outbound_message_queue.put(message_data)
|
||||
|
||||
await self.logger.info(f'Send message to {target_id}: {message}')
|
||||
|
||||
return message_data
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
) -> dict:
|
||||
"""回复消息 - 非流式"""
|
||||
# 获取会话和pipeline信息
|
||||
session = (
|
||||
self.websocket_group_session
|
||||
if isinstance(message_source, platform_events.GroupMessage)
|
||||
else self.websocket_person_session
|
||||
)
|
||||
|
||||
# 从message_source获取pipeline_uuid和connection_id
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
# session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person'
|
||||
|
||||
# 生成新的消息ID
|
||||
msg_id = len(session.get_message_list(pipeline_uuid)) + 1
|
||||
|
||||
message_data = WebSocketMessage(
|
||||
id=msg_id,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# 保存到历史记录
|
||||
session.get_message_list(pipeline_uuid).append(message_data)
|
||||
|
||||
# 直接广播到所有该pipeline的连接
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
pipeline_uuid,
|
||||
{
|
||||
'type': 'response',
|
||||
'data': message_data.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
) -> dict:
|
||||
"""回复消息块 - 流式"""
|
||||
# 获取会话和pipeline信息
|
||||
session = (
|
||||
self.websocket_group_session
|
||||
if isinstance(message_source, platform_events.GroupMessage)
|
||||
else self.websocket_person_session
|
||||
)
|
||||
|
||||
pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
message_list = session.get_message_list(pipeline_uuid)
|
||||
|
||||
# 检查是否是新的流式消息(通过bot_message对象判断)
|
||||
# 如果列表为空,或者最后一条消息已经is_final=True,则创建新消息
|
||||
if not message_list or message_list[-1].is_final:
|
||||
# 创建新消息
|
||||
msg_id = len(message_list) + 1
|
||||
message_data = WebSocketMessage(
|
||||
id=msg_id,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
is_final=is_final and bot_message.tool_calls is None,
|
||||
)
|
||||
|
||||
# 只有在is_final时才保存到历史记录
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list.append(message_data)
|
||||
else:
|
||||
# 更新最后一条消息
|
||||
msg_id = message_list[-1].id
|
||||
message_data = WebSocketMessage(
|
||||
id=msg_id,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=message_list[-1].timestamp, # 保持原始时间戳
|
||||
is_final=is_final and bot_message.tool_calls is None,
|
||||
)
|
||||
|
||||
# 如果是final,更新历史记录中的最后一条
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_list[-1] = message_data
|
||||
|
||||
# 直接广播到所有该pipeline的连接
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
pipeline_uuid,
|
||||
{
|
||||
'type': 'response',
|
||||
'data': message_data.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
"""WebSocket始终支持流式输出"""
|
||||
return True
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
func: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
|
||||
],
|
||||
):
|
||||
"""注册事件监听器"""
|
||||
self.listeners[event_type] = func
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
func: typing.Callable[
|
||||
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
|
||||
],
|
||||
):
|
||||
"""取消注册事件监听器"""
|
||||
del self.listeners[event_type]
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
async def run_async(self):
|
||||
"""运行适配器"""
|
||||
await self.logger.info('WebSocket适配器已启动')
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 处理主动推送消息
|
||||
if not self.outbound_message_queue.empty():
|
||||
try:
|
||||
message = await asyncio.wait_for(self.outbound_message_queue.get(), timeout=0.1)
|
||||
# 广播到所有相关连接
|
||||
target_id = message.get('target_id', '')
|
||||
await ws_connection_manager.broadcast_to_pipeline(target_id, message)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
await self.logger.info('WebSocket适配器已停止')
|
||||
raise
|
||||
|
||||
async def kill(self):
|
||||
"""停止适配器"""
|
||||
await self.logger.info('WebSocket适配器正在停止')
|
||||
|
||||
async def _process_image_components(self, message_chain_obj: list):
|
||||
"""
|
||||
处理消息链中的图片组件,将path转换为base64
|
||||
|
||||
Args:
|
||||
message_chain_obj: 消息链对象列表
|
||||
"""
|
||||
import base64
|
||||
|
||||
storage_mgr = self.ap.storage_mgr
|
||||
|
||||
for component in message_chain_obj:
|
||||
if component.get('type') == 'Image' and component.get('path'):
|
||||
try:
|
||||
# 从storage读取文件
|
||||
file_content = await storage_mgr.storage_provider.load(component['path'])
|
||||
|
||||
# 转换为base64
|
||||
base64_str = base64.b64encode(file_content).decode('utf-8')
|
||||
|
||||
# 添加data URI前缀(根据文件扩展名判断MIME类型)
|
||||
file_key = component['path']
|
||||
if file_key.lower().endswith(('.jpg', '.jpeg')):
|
||||
mime_type = 'image/jpeg'
|
||||
elif file_key.lower().endswith('.png'):
|
||||
mime_type = 'image/png'
|
||||
elif file_key.lower().endswith('.gif'):
|
||||
mime_type = 'image/gif'
|
||||
elif file_key.lower().endswith('.webp'):
|
||||
mime_type = 'image/webp'
|
||||
else:
|
||||
mime_type = 'image/png' # 默认
|
||||
|
||||
component['base64'] = f'data:{mime_type};base64,{base64_str}'
|
||||
await storage_mgr.storage_provider.delete(component['path'])
|
||||
component['path'] = ''
|
||||
# 保留path字段用于后端处理,前端使用base64显示
|
||||
except Exception as e:
|
||||
await self.logger.error(f'加载图片文件失败 {component["path"]}: {e}')
|
||||
|
||||
async def handle_websocket_message(
|
||||
self,
|
||||
connection: WebSocketConnection,
|
||||
message_data: dict,
|
||||
):
|
||||
"""
|
||||
处理从WebSocket接收的消息
|
||||
|
||||
这个方法只负责接收消息、保存到历史记录、并触发事件处理
|
||||
不等待任何响应,响应消息会通过reply_message/reply_message_chunk直接发送
|
||||
|
||||
Args:
|
||||
connection: WebSocket连接对象
|
||||
message_data: 消息数据
|
||||
"""
|
||||
pipeline_uuid = connection.pipeline_uuid
|
||||
session_type = connection.session_type
|
||||
|
||||
# 选择会话
|
||||
use_session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session
|
||||
|
||||
# 解析消息链
|
||||
message_chain_obj = message_data.get('message', [])
|
||||
|
||||
# 处理图片组件:将path转换为base64
|
||||
await self._process_image_components(message_chain_obj)
|
||||
|
||||
message_chain = platform_message.MessageChain.model_validate(message_chain_obj)
|
||||
|
||||
# 生成消息ID
|
||||
message_id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
|
||||
# 保存用户消息
|
||||
user_message = WebSocketMessage(
|
||||
id=message_id,
|
||||
role='user',
|
||||
content=str(message_chain),
|
||||
message_chain=message_chain_obj,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
connection_id=connection.connection_id,
|
||||
is_final=True, # 用户消息始终是完整的,非流式
|
||||
)
|
||||
use_session.get_message_list(pipeline_uuid).append(user_message)
|
||||
|
||||
# 广播用户消息到所有连接(包括发送者)
|
||||
await ws_connection_manager.broadcast_to_pipeline(
|
||||
pipeline_uuid,
|
||||
{
|
||||
'type': 'user_message',
|
||||
'data': user_message.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
# 添加消息源
|
||||
message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))
|
||||
|
||||
# 创建事件
|
||||
if session_type == 'person':
|
||||
sender = platform_entities.Friend(
|
||||
id=f'websocket_{connection.connection_id}', nickname='User', remark='User'
|
||||
)
|
||||
event = platform_events.FriendMessage(
|
||||
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
|
||||
)
|
||||
else:
|
||||
group = platform_entities.Group(
|
||||
id='websocketgroup', name='Group', permission=platform_entities.Permission.Member
|
||||
)
|
||||
sender = platform_entities.GroupMember(
|
||||
id=f'websocket_{connection.connection_id}',
|
||||
member_name='User',
|
||||
group=group,
|
||||
permission=platform_entities.Permission.Member,
|
||||
)
|
||||
event = platform_events.GroupMessage(
|
||||
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
|
||||
)
|
||||
|
||||
# 设置流水线UUID
|
||||
self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
|
||||
|
||||
# 异步触发事件处理(不等待结果)
|
||||
if event.__class__ in self.listeners:
|
||||
asyncio.create_task(self.listeners[event.__class__](event, self))
|
||||
|
||||
def get_websocket_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
|
||||
"""获取消息历史"""
|
||||
if session_type == 'person':
|
||||
return [message.model_dump() for message in self.websocket_person_session.get_message_list(pipeline_uuid)]
|
||||
else:
|
||||
return [message.model_dump() for message in self.websocket_group_session.get_message_list(pipeline_uuid)]
|
||||
|
||||
def reset_session(self, pipeline_uuid: str, session_type: str):
|
||||
"""重置会话"""
|
||||
if session_type == 'person':
|
||||
if pipeline_uuid in self.websocket_person_session.message_lists:
|
||||
self.websocket_person_session.message_lists[pipeline_uuid] = []
|
||||
else:
|
||||
if pipeline_uuid in self.websocket_group_session.message_lists:
|
||||
self.websocket_group_session.message_lists[pipeline_uuid] = []
|
||||
@@ -0,0 +1,177 @@
|
||||
"""WebSocket连接管理器 - 管理多个并发WebSocket连接"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pydantic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketConnection(pydantic.BaseModel):
|
||||
"""单个WebSocket连接"""
|
||||
|
||||
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
connection_id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""连接唯一ID"""
|
||||
|
||||
pipeline_uuid: str
|
||||
"""关联的流水线UUID"""
|
||||
|
||||
session_type: str # 'person' or 'group'
|
||||
"""会话类型"""
|
||||
|
||||
websocket: typing.Any = pydantic.Field(exclude=True)
|
||||
"""WebSocket连接对象 (quart.websocket)"""
|
||||
|
||||
created_at: datetime = pydantic.Field(default_factory=datetime.now)
|
||||
"""连接创建时间"""
|
||||
|
||||
last_active: datetime = pydantic.Field(default_factory=datetime.now)
|
||||
"""最后活跃时间"""
|
||||
|
||||
send_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True)
|
||||
"""发送消息队列"""
|
||||
|
||||
is_active: bool = True
|
||||
"""连接是否活跃"""
|
||||
|
||||
metadata: dict = pydantic.Field(default_factory=dict)
|
||||
"""连接元数据(可存储额外信息)"""
|
||||
|
||||
|
||||
class WebSocketConnectionManager:
|
||||
"""WebSocket连接管理器 - 支持多连接并发"""
|
||||
|
||||
def __init__(self):
|
||||
self.connections: dict[str, WebSocketConnection] = {}
|
||||
"""所有活跃连接 {connection_id: connection}"""
|
||||
|
||||
self.pipeline_connections: dict[str, set[str]] = {}
|
||||
"""流水线到连接的映射 {pipeline_uuid: {connection_id, ...}}"""
|
||||
|
||||
self.session_connections: dict[str, set[str]] = {}
|
||||
"""会话类型到连接的映射 {session_type: {connection_id, ...}}"""
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
"""线程锁,保护并发访问"""
|
||||
|
||||
async def add_connection(
|
||||
self,
|
||||
websocket: typing.Any,
|
||||
pipeline_uuid: str,
|
||||
session_type: str,
|
||||
metadata: dict = None,
|
||||
) -> WebSocketConnection:
|
||||
"""添加新的WebSocket连接"""
|
||||
async with self._lock:
|
||||
connection = WebSocketConnection(
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
session_type=session_type,
|
||||
websocket=websocket,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self.connections[connection.connection_id] = connection
|
||||
|
||||
# 更新流水线映射
|
||||
if pipeline_uuid not in self.pipeline_connections:
|
||||
self.pipeline_connections[pipeline_uuid] = set()
|
||||
self.pipeline_connections[pipeline_uuid].add(connection.connection_id)
|
||||
|
||||
# 更新会话类型映射
|
||||
if session_type not in self.session_connections:
|
||||
self.session_connections[session_type] = set()
|
||||
self.session_connections[session_type].add(connection.connection_id)
|
||||
|
||||
logger.debug(
|
||||
f'WebSocket connection established: {connection.connection_id} '
|
||||
f'(pipeline={pipeline_uuid}, session_type={session_type})'
|
||||
)
|
||||
|
||||
return connection
|
||||
|
||||
async def remove_connection(self, connection_id: str):
|
||||
"""移除WebSocket连接"""
|
||||
async with self._lock:
|
||||
if connection_id not in self.connections:
|
||||
return
|
||||
|
||||
connection = self.connections[connection_id]
|
||||
connection.is_active = False
|
||||
|
||||
# 从流水线映射中移除
|
||||
if connection.pipeline_uuid in self.pipeline_connections:
|
||||
self.pipeline_connections[connection.pipeline_uuid].discard(connection_id)
|
||||
if not self.pipeline_connections[connection.pipeline_uuid]:
|
||||
del self.pipeline_connections[connection.pipeline_uuid]
|
||||
|
||||
# 从会话类型映射中移除
|
||||
if connection.session_type in self.session_connections:
|
||||
self.session_connections[connection.session_type].discard(connection_id)
|
||||
if not self.session_connections[connection.session_type]:
|
||||
del self.session_connections[connection.session_type]
|
||||
|
||||
del self.connections[connection_id]
|
||||
|
||||
logger.debug(f'WebSocket connection disconnected: {connection_id}')
|
||||
|
||||
async def get_connection(self, connection_id: str) -> typing.Optional[WebSocketConnection]:
|
||||
"""获取指定连接"""
|
||||
return self.connections.get(connection_id)
|
||||
|
||||
async def get_connections_by_pipeline(self, pipeline_uuid: str) -> list[WebSocketConnection]:
|
||||
"""获取指定流水线的所有连接"""
|
||||
connection_ids = self.pipeline_connections.get(pipeline_uuid, set())
|
||||
return [self.connections[cid] for cid in connection_ids if cid in self.connections]
|
||||
|
||||
async def get_connections_by_session_type(self, session_type: str) -> list[WebSocketConnection]:
|
||||
"""获取指定会话类型的所有连接"""
|
||||
connection_ids = self.session_connections.get(session_type, set())
|
||||
return [self.connections[cid] for cid in connection_ids if cid in self.connections]
|
||||
|
||||
async def broadcast_to_pipeline(self, pipeline_uuid: str, message: dict):
|
||||
"""向指定流水线的所有连接广播消息"""
|
||||
connections = await self.get_connections_by_pipeline(pipeline_uuid)
|
||||
tasks = []
|
||||
for conn in connections:
|
||||
tasks.append(self.send_to_connection(conn.connection_id, message))
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def send_to_connection(self, connection_id: str, message: dict):
|
||||
"""向指定连接发送消息"""
|
||||
connection = await self.get_connection(connection_id)
|
||||
if not connection or not connection.is_active:
|
||||
logger.warning(f'Attempt to send message to invalid connection: {connection_id}')
|
||||
return
|
||||
|
||||
try:
|
||||
await connection.send_queue.put(message)
|
||||
connection.last_active = datetime.now()
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to send message to connection {connection_id}: {e}')
|
||||
await self.remove_connection(connection_id)
|
||||
|
||||
async def update_activity(self, connection_id: str):
|
||||
"""更新连接活跃时间"""
|
||||
connection = await self.get_connection(connection_id)
|
||||
if connection:
|
||||
connection.last_active = datetime.now()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取连接统计信息"""
|
||||
return {
|
||||
'total_connections': len(self.connections),
|
||||
'pipelines': len(self.pipeline_connections),
|
||||
'connections_by_pipeline': {k: len(v) for k, v in self.pipeline_connections.items()},
|
||||
'connections_by_session_type': {k: len(v) for k, v in self.session_connections.items()},
|
||||
}
|
||||
|
||||
|
||||
# 全局连接管理器实例
|
||||
ws_connection_manager = WebSocketConnectionManager()
|
||||
Reference in New Issue
Block a user