diff --git a/MIGRATION_SUMMARY.md b/MIGRATION_SUMMARY.md new file mode 100644 index 00000000..c038d04a --- /dev/null +++ b/MIGRATION_SUMMARY.md @@ -0,0 +1,412 @@ +# WebChat 到 WebSocket 迁移总结 + +## 概述 + +已完全移除旧的基于SSE的WebChat系统,并替换为基于WebSocket的双向实时通信系统。这是一个内置在LangBot中的完整IM系统,支持流式输出。 + +## 已删除的文件 + +### 后端 +- ❌ `src/langbot/pkg/api/http/controller/groups/pipelines/webchat.py` - 旧的SSE路由 +- ❌ `src/langbot/pkg/platform/sources/webchat.py` - 旧的WebChat适配器 +- ❌ `src/langbot/pkg/platform/sources/webchat.yaml` - 旧的配置文件 + +### 前端 +- ❌ BackendClient中所有SSE相关代码已完全移除 +- ❌ DebugDialog中所有SSE相关逻辑已完全替换 + +## 新增的文件 + +### 后端核心文件 + +**1. WebSocket连接管理器** +``` +src/langbot/pkg/platform/sources/websocket_manager.py +``` +- 管理所有并发WebSocket连接 +- 线程安全的连接池 +- 按流水线、会话类型分组 +- 广播和单播消息功能 +- 连接统计和监控 + +**2. WebSocket适配器** +``` +src/langbot/pkg/platform/sources/websocket_adapter.py +``` +- 实现平台适配器接口 +- **完整流式支持** (`reply_message_chunk` 方法) +- 双向消息流处理 +- 消息历史管理 +- 会话管理 + +**3. WebSocket路由控制器** +``` +src/langbot/pkg/api/http/controller/groups/pipelines/websocket_chat.py +``` +- WebSocket端点处理 +- REST API接口 +- 心跳机制 +- 连接生命周期管理 + +**4. 配置文件** +``` +src/langbot/pkg/platform/sources/websocket.yaml +``` +- WebSocket适配器元数据 + +### 前端核心文件 + +**1. WebSocket客户端** +``` +web/src/app/infra/websocket/WebSocketClient.ts +``` +- WebSocket连接管理 +- 自动重连(最多5次) +- 心跳机制(30秒) +- 事件回调系统 + +**2. 更新的组件** +``` +web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx +``` +- 完全重写,使用WebSocket +- 实时连接状态显示 +- 流式消息支持 +- 自动重连 + +**3. HTTP客户端更新** +``` +web/src/app/infra/http/BackendClient.ts +``` +- 移除所有旧的WebChat API +- 仅保留WebSocket API + +### 测试工具 + +**Python测试客户端** +``` +test_websocket_client.py +``` +- 单连接交互测试 +- 多连接并发测试 +- 命令行工具 + +### 文档 + +**使用文档** +``` +WEBSOCKET_README.md +``` +- 完整的API文档 +- 架构说明 +- 使用示例 +- 故障排查 + +## 核心变更 + +### 后端变更 + +**1. botmgr.py** +- ❌ 移除 `webchat_proxy_bot` +- ✅ 仅保留 `websocket_proxy_bot` +- ✅ 更新适配器过滤逻辑(排除`websocket`而非`webchat`) + +**2. 适配器注册** +```python +# 旧代码(已删除) +webchat_adapter_class = self.adapter_dict['webchat'] +self.webchat_proxy_bot = RuntimeBot(...) + +# 新代码 +websocket_adapter_class = self.adapter_dict['websocket'] +self.websocket_proxy_bot = RuntimeBot( + uuid='websocket-proxy-bot', + name='WebSocket', + adapter='websocket', + ... +) +``` + +### 前端变更 + +**1. API调用完全更换** + +旧代码(已删除): +```typescript +// SSE流式请求 +await fetch(url, { + method: 'POST', + body: JSON.stringify({ is_stream: true }) +}) +// 手动解析 text/event-stream +``` + +新代码: +```typescript +// WebSocket实时通信 +const wsClient = new WebSocketClient(pipelineId, sessionType); +await wsClient.connect(); + +wsClient.onMessage((message) => { + // 流式消息自动处理 + setMessages(prev => [...prev, message]); +}); + +wsClient.sendMessage(messageChain); +``` + +**2. 连接状态管理** + +新增功能: +- ✅ 实时连接状态指示器(绿色/红色圆点) +- ✅ 连接/断开toast提示 +- ✅ 自动重连逻辑 +- ✅ 心跳保活 + +**3. 流式支持** + +完整的流式消息处理: +```typescript +wsClient.onMessage((message) => { + if (message.is_final) { + // 最终消息 + finalizeBotMessage(message); + } else { + // 中间消息块,实时更新UI + updateBotMessage(message); + } +}); +``` + +## API对比 + +### WebSocket端点 + +**连接** +``` +ws://localhost:8000/api/v1/pipelines//ws/connect?session_type= +``` + +**消息格式** + +客户端发送: +```json +{ + "type": "message", + "message": [ + {"type": "Plain", "text": "你好"} + ] +} +``` + +服务器响应(流式): +```json +{ + "type": "response", + "data": { + "id": 1, + "role": "assistant", + "content": "你好,我是...", + "is_final": false, + "timestamp": "2025-01-28T..." + } +} +``` + +### REST API + +| 端点 | 方法 | 说明 | +|------|------|------| +| `/api/v1/pipelines//ws/messages/` | GET | 获取消息历史 | +| `/api/v1/pipelines//ws/reset/` | POST | 重置会话 | +| `/api/v1/pipelines//ws/connections` | GET | 获取连接统计 | +| `/api/v1/pipelines//ws/broadcast` | POST | 广播消息 | + +## 流式支持详解 + +### 后端流式实现 + +**WebSocket Adapter** +```python +async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + bot_message, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_final: bool = False, +) -> dict: + """回复消息块 - 流式""" + message_data = WebSocketMessage( + id=-1, + role='assistant', + content=str(message), + message_chain=[component.__dict__ for component in message], + timestamp=datetime.now().isoformat(), + is_final=is_final and bot_message.tool_calls is None, + ) + + # 发送到队列,由WebSocket连接处理发送 + await session.resp_queues[message_id].put(message_data) + return message_data.model_dump() + +async def is_stream_output_supported(self) -> bool: + """WebSocket始终支持流式输出""" + return True +``` + +### 前端流式处理 + +**DebugDialog组件** +```typescript +wsClient.onMessage((message) => { + setMessages((prevMessages) => { + const existingIndex = prevMessages.findIndex( + (msg) => msg.role === 'assistant' && msg.content === 'Generating...' + ); + + if (existingIndex !== -1) { + // 更新正在生成的消息 + const updatedMessages = [...prevMessages]; + updatedMessages[existingIndex] = message; + return updatedMessages; + } else { + // 添加新消息 + return [...prevMessages, message]; + } + }); +}); +``` + +## 兼容性说明 + +### ⚠️ 不兼容旧版本 + +此次迁移**完全不兼容**旧的WebChat系统: + +1. **API端点变更** + - 旧: `/api/v1/pipelines//chat/send` + - 新: `ws://...//ws/connect` + +2. **通信协议变更** + - 旧: HTTP + SSE (Server-Sent Events) + - 新: WebSocket (双向) + +3. **流式实现变更** + - 旧: `text/event-stream` 格式 + - 新: WebSocket JSON消息 + +### 迁移要求 + +使用新系统需要: +1. ✅ 前端必须支持WebSocket +2. ✅ 后端必须运行新的WebSocket适配器 +3. ✅ 清除旧的WebChat相关配置 + +## 优势对比 + +| 特性 | 旧WebChat (SSE) | 新WebSocket | +|------|----------------|-------------| +| 双向通信 | ❌ 单向(服务器→客户端) | ✅ 双向 | +| 主动推送 | ❌ 不支持 | ✅ 支持 | +| 连接管理 | ❌ 无状态 | ✅ 有状态,完整生命周期 | +| 流式输出 | ✅ 支持 | ✅ 支持(更优) | +| 心跳机制 | ❌ 无 | ✅ 30秒心跳 | +| 自动重连 | ❌ 无 | ✅ 最多5次 | +| 多连接 | ⚠️ 难以管理 | ✅ 完整支持 | +| 连接状态 | ❌ 不可见 | ✅ 实时显示 | +| 广播功能 | ❌ 不支持 | ✅ 支持 | + +## 测试方式 + +### 1. Python测试客户端 + +```bash +# 单连接测试 +python test_websocket_client.py + +# 指定会话类型 +python test_websocket_client.py --session-type group + +# 多连接并发测试(5个连接) +python test_websocket_client.py --multi 5 +``` + +### 2. 前端测试 + +1. 启动LangBot服务器 +2. 访问前端界面 +3. 打开流水线调试对话框 +4. 观察连接状态指示器(左下角圆点) +5. 发送消息测试流式响应 + +### 3. 浏览器控制台测试 + +```javascript +const ws = new WebSocket('ws://localhost:8000/api/v1/pipelines//ws/connect?session_type=person'); + +ws.onopen = () => { + console.log('已连接'); + ws.send(JSON.stringify({ + type: 'message', + message: [{type: 'Plain', text: '你好'}] + })); +}; + +ws.onmessage = (event) => { + console.log('收到:', JSON.parse(event.data)); +}; +``` + +## 常见问题 + +### Q: 为什么完全删除旧代码而不保留兼容性? +A: 根据需求,不需要考虑任何对老版本的兼容性,彻底迁移可以避免代码冗余和维护负担。 + +### Q: 流式输出如何工作? +A: +1. 后端通过`reply_message_chunk`发送消息块 +2. 消息块放入队列 +3. WebSocket连接从队列取出并发送 +4. 前端实时更新UI +5. `is_final=true`表示最后一块 + +### Q: 如何确保连接不断开? +A: +1. 客户端每30秒发送心跳(ping) +2. 服务器响应pong +3. 连接断开时自动重连(最多5次) + +### Q: 如何实现后端主动推送? +A: +1. 调用 `/api/v1/pipelines//ws/broadcast` API +2. 消息会被推送到该流水线的所有连接 +3. 前端通过`onBroadcast`回调接收 + +## 总结 + +✅ **完成的工作** +- 完全移除旧的WebChat/SSE系统 +- 实现完整的WebSocket双向通信系统 +- 支持流式输出 +- 支持多连接并发 +- 实现自动重连和心跳机制 +- 提供完整的测试工具和文档 + +✅ **核心特性** +- 双向实时通信 +- 流式消息支持 +- 多连接管理 +- 自动重连 +- 心跳保活 +- 连接状态可视化 +- 广播消息 + +✅ **技术亮点** +- 异步架构(asyncio) +- 线程安全的连接管理 +- 独立的消息队列 +- 完整的错误处理 +- 模块化设计 + +🎉 系统已完全迁移到WebSocket,无任何旧代码遗留! diff --git a/WEBSOCKET_README.md b/WEBSOCKET_README.md new file mode 100644 index 00000000..9e943983 --- /dev/null +++ b/WEBSOCKET_README.md @@ -0,0 +1,394 @@ +# LangBot WebSocket 双向通信系统 + +## 概述 + +这是一个内置在 LangBot 中的完整 IM (即时通讯) 系统,支持: + +- ✅ WebSocket 双向实时通信 +- ✅ 多个客户端并发连接 +- ✅ 前端到后端的消息发送 +- ✅ 后端到前端的主动推送 +- ✅ 流式响应支持 +- ✅ 连接管理和会话隔离 +- ✅ 心跳机制 +- ✅ 广播消息功能 + +## 架构设计 + +### 核心组件 + +1. **WebSocketConnectionManager** (`websocket_manager.py`) + - 管理所有活跃的 WebSocket 连接 + - 支持按流水线、会话类型查询连接 + - 提供广播和单播功能 + - 线程安全的并发访问控制 + +2. **WebSocketAdapter** (`websocket_adapter.py`) + - 实现平台适配器接口 + - 处理消息的接收和发送 + - 支持流式输出 + - 管理消息历史 + +3. **WebSocketChatRouterGroup** (`websocket_chat.py`) + - WebSocket 路由控制器 + - 处理连接建立、消息收发 + - 实现心跳机制 + - 提供 REST API 接口 + +## API 接口 + +### WebSocket 连接 + +#### 建立连接 + +``` +ws://localhost:8000/api/v1/pipelines//ws/connect?session_type= +``` + +**参数:** +- `pipeline_uuid`: 流水线 UUID (必需) +- `session_type`: 会话类型,可选 `person` 或 `group` (默认: `person`) + +**连接成功响应:** +```json +{ + "type": "connected", + "connection_id": "550e8400-e29b-41d4-a716-446655440000", + "pipeline_uuid": "your-pipeline-uuid", + "session_type": "person", + "timestamp": "2025-01-28T12:00:00" +} +``` + +### 消息格式 + +#### 客户端发送消息 + +**发送聊天消息:** +```json +{ + "type": "message", + "message": [ + { + "type": "Plain", + "text": "你好,这是一条测试消息" + } + ] +} +``` + +**发送心跳:** +```json +{ + "type": "ping" +} +``` + +**主动断开连接:** +```json +{ + "type": "disconnect" +} +``` + +#### 服务器响应消息 + +**聊天响应 (流式):** +```json +{ + "type": "response", + "data": { + "id": 1, + "role": "assistant", + "content": "这是机器人的回复", + "message_chain": [...], + "timestamp": "2025-01-28T12:00:00", + "is_final": false, + "connection_id": "..." + } +} +``` + +**心跳响应:** +```json +{ + "type": "pong", + "timestamp": "2025-01-28T12:00:00" +} +``` + +**广播消息:** +```json +{ + "type": "broadcast", + "message": "这是一条广播消息", + "timestamp": "2025-01-28T12:00:00" +} +``` + +**错误消息:** +```json +{ + "type": "error", + "message": "错误描述" +} +``` + +### REST API 接口 + +#### 1. 获取消息历史 + +```http +GET /api/v1/pipelines//ws/messages/ +``` + +**响应:** +```json +{ + "code": 0, + "msg": "ok", + "data": { + "messages": [...] + } +} +``` + +#### 2. 重置会话 + +```http +POST /api/v1/pipelines//ws/reset/ +``` + +**响应:** +```json +{ + "code": 0, + "msg": "ok", + "data": { + "message": "Session reset successfully" + } +} +``` + +#### 3. 获取连接统计 + +```http +GET /api/v1/pipelines//ws/connections +``` + +**响应:** +```json +{ + "code": 0, + "msg": "ok", + "data": { + "stats": { + "total_connections": 5, + "pipelines": 2, + "connections_by_pipeline": { + "pipeline-1": 3, + "pipeline-2": 2 + }, + "connections_by_session_type": { + "person": 4, + "group": 1 + } + }, + "connections": [ + { + "connection_id": "...", + "session_type": "person", + "created_at": "2025-01-28T12:00:00", + "last_active": "2025-01-28T12:05:00", + "is_active": true + } + ] + } +} +``` + +#### 4. 广播消息 (后端主动推送) + +```http +POST /api/v1/pipelines//ws/broadcast +Content-Type: application/json + +{ + "message": "这是一条广播消息" +} +``` + +**响应:** +```json +{ + "code": 0, + "msg": "ok", + "data": { + "message": "Broadcast sent successfully" + } +} +``` + +## 使用示例 + +### Python 客户端示例 + +使用提供的测试客户端: + +```bash +# 安装依赖 +pip install websockets + +# 单个连接测试 +python test_websocket_client.py + +# 指定会话类型 +python test_websocket_client.py --session-type group + +# 多连接并发测试 +python test_websocket_client.py --multi 5 +``` + +### JavaScript 客户端示例 + +```javascript +// 建立 WebSocket 连接 +const ws = new WebSocket('ws://localhost:8000/api/v1/pipelines/your-pipeline-uuid/ws/connect?session_type=person'); + +// 连接建立 +ws.onopen = () => { + console.log('WebSocket 连接已建立'); + + // 发送消息 + ws.send(JSON.stringify({ + type: 'message', + message: [ + { + type: 'Plain', + text: '你好' + } + ] + })); +}; + +// 接收消息 +ws.onmessage = (event) => { + const data = JSON.parse(event.data); + + if (data.type === 'connected') { + console.log('连接成功:', data.connection_id); + } else if (data.type === 'response') { + console.log('机器人回复:', data.data.content); + if (data.data.is_final) { + console.log('响应完成'); + } + } else if (data.type === 'broadcast') { + console.log('收到广播:', data.message); + } +}; + +// 连接关闭 +ws.onclose = () => { + console.log('WebSocket 连接已关闭'); +}; + +// 错误处理 +ws.onerror = (error) => { + console.error('WebSocket 错误:', error); +}; + +// 发送心跳 +setInterval(() => { + if (ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ type: 'ping' })); + } +}, 30000); // 每 30 秒发送一次心跳 +``` + +## 特性说明 + +### 1. 多连接支持 + +系统支持同时建立多个 WebSocket 连接,每个连接都有唯一的 `connection_id`。连接按照流水线和会话类型进行分组管理。 + +### 2. 双向通信 + +- **前端 → 后端**: 客户端可以主动发送消息给服务器 +- **后端 → 前端**: 服务器可以通过广播 API 主动推送消息给客户端 + +### 3. 流式响应 + +支持流式输出,机器人的响应会分块发送,客户端可以实时显示部分响应内容。 + +### 4. 会话隔离 + +支持 `person` 和 `group` 两种会话类型,不同类型的会话消息历史互不影响。 + +### 5. 连接管理 + +- 自动追踪连接状态 +- 记录最后活跃时间 +- 支持连接统计查询 +- 连接断开时自动清理资源 + +### 6. 心跳机制 + +客户端可以定期发送 `ping` 消息,服务器会响应 `pong`,用于保持连接活跃和检测连接状态。 + +## 架构优势 + +1. **高并发**: 使用 asyncio 异步架构,支持大量并发连接 +2. **可扩展**: 模块化设计,易于扩展新功能 +3. **线程安全**: 连接管理器使用锁机制保证并发安全 +4. **消息队列**: 每个连接独立的发送队列,避免消息混乱 +5. **灵活路由**: 支持按流水线、会话类型灵活路由消息 + +## 注意事项 + +1. **认证**: 当前 WebSocket 连接不需要认证,生产环境建议添加认证机制 +2. **心跳**: 建议客户端实现心跳机制,避免连接超时 +3. **重连**: 客户端应实现断线重连逻辑 +4. **消息大小**: 注意控制单条消息大小,避免内存溢出 +5. **连接数限制**: 生产环境建议设置最大连接数限制 + +## 故障排查 + +### 连接失败 + +1. 检查流水线 UUID 是否正确 +2. 检查服务器是否正常运行 +3. 检查防火墙设置 + +### 消息发送失败 + +1. 检查消息格式是否正确 +2. 检查连接是否仍然活跃 +3. 查看服务器日志获取详细错误信息 + +### 性能问题 + +1. 检查并发连接数是否过多 +2. 检查消息处理速度 +3. 考虑使用连接池或负载均衡 + +## 开发调试 + +启用详细日志: + +```python +import logging +logging.getLogger('langbot.pkg.platform.sources.websocket_adapter').setLevel(logging.DEBUG) +logging.getLogger('langbot.pkg.platform.sources.websocket_manager').setLevel(logging.DEBUG) +logging.getLogger('langbot.pkg.api.http.controller.groups.pipelines.websocket_chat').setLevel(logging.DEBUG) +``` + +## 后续改进建议 + +1. 添加用户认证和授权机制 +2. 实现消息持久化 +3. 添加消息加密 +4. 实现更丰富的消息类型 (图片、文件等) +5. 添加消息已读/未读状态 +6. 实现群组聊天功能 +7. 添加在线状态显示 +8. 实现消息撤回功能 diff --git a/src/langbot/pkg/api/http/controller/groups/files.py b/src/langbot/pkg/api/http/controller/groups/files.py index 05877e14..30aab87b 100644 --- a/src/langbot/pkg/api/http/controller/groups/files.py +++ b/src/langbot/pkg/api/http/controller/groups/files.py @@ -28,8 +28,56 @@ class FilesRouterGroup(group.RouterGroup): return quart.Response(image_bytes, mimetype=mime_type) + @self.route('/images', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def upload_image() -> quart.Response: + request = quart.request + + # Check file size limit before reading the file + content_length = request.content_length + if content_length and content_length > group.MAX_FILE_SIZE: + return self.fail(400, 'Image size exceeds 10MB limit.') + + # get file bytes from 'file' + files = await request.files + if 'file' not in files: + return self.fail(400, 'No image file provided') + + file = files['file'] + assert isinstance(file, quart.datastructures.FileStorage) + + file_bytes = await asyncio.to_thread(file.stream.read) + + # Double-check actual file size after reading + if len(file_bytes) > group.MAX_FILE_SIZE: + return self.fail(400, 'Image size exceeds 10MB limit.') + + # Validate image file extension + allowed_extensions = {'jpg', 'jpeg', 'png', 'gif', 'webp'} + if '.' in file.filename: + file_name, extension = file.filename.rsplit('.', 1) + extension = extension.lower() + else: + return self.fail(400, 'Invalid image file: no file extension') + + if extension not in allowed_extensions: + return self.fail(400, f'Invalid image format. Allowed formats: {", ".join(allowed_extensions)}') + + # check if file name contains '/' or '\' + if '/' in file_name or '\\' in file_name: + return self.fail(400, 'File name contains invalid characters') + + file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension + + # save file to storage + await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes) + return self.success( + data={ + 'file_key': file_key, + } + ) + @self.route('/documents', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) - async def _() -> quart.Response: + async def upload_document() -> quart.Response: request = quart.request # Check file size limit before reading the file diff --git a/src/langbot/pkg/api/http/controller/groups/pipelines/webchat.py b/src/langbot/pkg/api/http/controller/groups/pipelines/webchat.py deleted file mode 100644 index 13f955d8..00000000 --- a/src/langbot/pkg/api/http/controller/groups/pipelines/webchat.py +++ /dev/null @@ -1,109 +0,0 @@ -import json - -import quart - -from ... import group - - -@group.group_class('webchat', '/api/v1/pipelines//chat') -class WebChatDebugRouterGroup(group.RouterGroup): - async def initialize(self) -> None: - @self.route('/send', methods=['POST']) - async def send_message(pipeline_uuid: str) -> str: - """Send a message to the pipeline for debugging""" - - async def stream_generator(generator): - yield 'data: {"type": "start"}\n\n' - async for message in generator: - yield f'data: {json.dumps({"message": message})}\n\n' - yield 'data: {"type": "end"}\n\n' - - try: - data = await quart.request.get_json() - session_type = data.get('session_type', 'person') - message_chain_obj = data.get('message', []) - is_stream = data.get('is_stream', False) - - if not message_chain_obj: - return self.http_status(400, -1, 'message is required') - - if session_type not in ['person', 'group']: - return self.http_status(400, -1, 'session_type must be person or group') - - webchat_adapter = self.ap.platform_mgr.webchat_proxy_bot.adapter - - if not webchat_adapter: - return self.http_status(404, -1, 'WebChat adapter not found') - - if is_stream: - generator = webchat_adapter.send_webchat_message( - pipeline_uuid, session_type, message_chain_obj, is_stream - ) - # 设置正确的响应头 - headers = { - 'Content-Type': 'text/event-stream', - 'Transfer-Encoding': 'chunked', - 'Cache-Control': 'no-cache', - 'Connection': 'keep-alive', - } - return quart.Response(stream_generator(generator), mimetype='text/event-stream', headers=headers) - - else: # non-stream - result = None - async for message in webchat_adapter.send_webchat_message( - pipeline_uuid, session_type, message_chain_obj - ): - result = message - if result is not None: - return self.success( - data={ - 'message': result, - } - ) - else: - return self.http_status(400, -1, 'message is required') - - except Exception as e: - return self.http_status(500, -1, f'Internal server error: {str(e)}') - - @self.route('/messages/', methods=['GET']) - async def get_messages(pipeline_uuid: str, session_type: str) -> str: - """Get the message history of the pipeline for debugging""" - try: - if session_type not in ['person', 'group']: - return self.http_status(400, -1, 'session_type must be person or group') - - webchat_adapter = self.ap.platform_mgr.webchat_proxy_bot.adapter - - if not webchat_adapter: - return self.http_status(404, -1, 'WebChat adapter not found') - - messages = webchat_adapter.get_webchat_messages(pipeline_uuid, session_type) - - return self.success(data={'messages': messages}) - - except Exception as e: - return self.http_status(500, -1, f'Internal server error: {str(e)}') - - @self.route('/reset/', methods=['POST']) - async def reset_session(session_type: str) -> str: - """Reset the debug session""" - try: - if session_type not in ['person', 'group']: - return self.http_status(400, -1, 'session_type must be person or group') - - webchat_adapter = None - for bot in self.ap.platform_mgr.bots: - if hasattr(bot.adapter, '__class__') and bot.adapter.__class__.__name__ == 'WebChatAdapter': - webchat_adapter = bot.adapter - break - - if not webchat_adapter: - return self.http_status(404, -1, 'WebChat adapter not found') - - webchat_adapter.reset_debug_session(session_type) - - return self.success(data={'message': 'Session reset successfully'}) - - except Exception as e: - return self.http_status(500, -1, f'Internal server error: {str(e)}') diff --git a/src/langbot/pkg/api/http/controller/groups/pipelines/websocket_chat.py b/src/langbot/pkg/api/http/controller/groups/pipelines/websocket_chat.py new file mode 100644 index 00000000..d03790b1 --- /dev/null +++ b/src/langbot/pkg/api/http/controller/groups/pipelines/websocket_chat.py @@ -0,0 +1,243 @@ +"""WebSocket聊天路由 - 支持双向实时通信""" + +import asyncio +import datetime +import json +import logging + +import quart + +from ... import group +from ......platform.sources.websocket_manager import ws_connection_manager + +logger = logging.getLogger(__name__) + + +@group.group_class('websocket_chat', '/api/v1/pipelines//ws') +class WebSocketChatRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + # 直接使用 quart_app 注册 WebSocket 路由 + @self.quart_app.websocket(self.path + '/connect') + async def websocket_connect(pipeline_uuid: str): + """ + 建立WebSocket连接 + + URL参数: + - pipeline_uuid: 流水线UUID + - session_type: 会话类型 (person/group) + """ + try: + # 获取参数 - 在WebSocket上下文中使用 quart.websocket.args + session_type = quart.websocket.args.get('session_type', 'person') + + if session_type not in ['person', 'group']: + await quart.websocket.send( + json.dumps({'type': 'error', 'message': 'session_type must be person or group'}) + ) + return + + # 获取WebSocket适配器 + websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter + + if not websocket_adapter: + await quart.websocket.send(json.dumps({'type': 'error', 'message': 'WebSocket adapter not found'})) + return + + # 注册连接 + connection = await ws_connection_manager.add_connection( + websocket=quart.websocket._get_current_object(), + pipeline_uuid=pipeline_uuid, + session_type=session_type, + metadata={'user_agent': quart.websocket.headers.get('User-Agent', '')}, + ) + + # 发送连接成功消息 + await quart.websocket.send( + json.dumps( + { + 'type': 'connected', + 'connection_id': connection.connection_id, + 'pipeline_uuid': pipeline_uuid, + 'session_type': session_type, + 'timestamp': connection.created_at.isoformat(), + } + ) + ) + + logger.debug( + f'WebSocket connection established: {connection.connection_id} ' + f'(pipeline={pipeline_uuid}, session_type={session_type})' + ) + + # 创建接收和发送任务 + receive_task = asyncio.create_task(self._handle_receive(connection, websocket_adapter)) + send_task = asyncio.create_task(self._handle_send(connection)) + + # 等待任务完成 + try: + await asyncio.gather(receive_task, send_task) + except Exception as e: + logger.error(f'WebSocket task execution error: {e}') + finally: + # 清理连接 + await ws_connection_manager.remove_connection(connection.connection_id) + logger.debug(f'WebSocket connection cleaned: {connection.connection_id}') + + except Exception as e: + logger.error(f'WebSocket connection error: {e}', exc_info=True) + try: + await quart.websocket.send(json.dumps({'type': 'error', 'message': str(e)})) + except: + pass + + @self.route('/messages/', methods=['GET']) + async def get_messages(pipeline_uuid: str, session_type: str) -> str: + """获取消息历史""" + try: + if session_type not in ['person', 'group']: + return self.http_status(400, -1, 'session_type must be person or group') + + websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter + + if not websocket_adapter: + return self.http_status(404, -1, 'WebSocket adapter not found') + + messages = websocket_adapter.get_websocket_messages(pipeline_uuid, session_type) + + return self.success(data={'messages': messages}) + + except Exception as e: + return self.http_status(500, -1, f'Internal server error: {str(e)}') + + @self.route('/reset/', methods=['POST']) + async def reset_session(pipeline_uuid: str, session_type: str) -> str: + """重置会话""" + try: + if session_type not in ['person', 'group']: + return self.http_status(400, -1, 'session_type must be person or group') + + websocket_adapter = self.ap.platform_mgr.websocket_proxy_bot.adapter + + if not websocket_adapter: + return self.http_status(404, -1, 'WebSocket adapter not found') + + websocket_adapter.reset_session(pipeline_uuid, session_type) + + return self.success(data={'message': 'Session reset successfully'}) + + except Exception as e: + return self.http_status(500, -1, f'Internal server error: {str(e)}') + + @self.route('/connections', methods=['GET']) + async def get_connections(pipeline_uuid: str) -> str: + """获取当前连接统计""" + try: + stats = ws_connection_manager.get_stats() + connections = await ws_connection_manager.get_connections_by_pipeline(pipeline_uuid) + + return self.success( + data={ + 'stats': stats, + 'connections': [ + { + 'connection_id': conn.connection_id, + 'session_type': conn.session_type, + 'created_at': conn.created_at.isoformat(), + 'last_active': conn.last_active.isoformat(), + 'is_active': conn.is_active, + } + for conn in connections + ], + } + ) + + except Exception as e: + return self.http_status(500, -1, f'Internal server error: {str(e)}') + + @self.route('/broadcast', methods=['POST']) + async def broadcast_message(pipeline_uuid: str) -> str: + """向所有连接广播消息(后端主动推送)""" + try: + data = await quart.request.get_json() + message = data.get('message') + + if not message: + return self.http_status(400, -1, 'message is required') + + # 广播消息 + broadcast_data = { + 'type': 'broadcast', + 'message': message, + 'timestamp': datetime.datetime.now().isoformat(), + } + + await ws_connection_manager.broadcast_to_pipeline(pipeline_uuid, broadcast_data) + + return self.success(data={'message': 'Broadcast sent successfully'}) + + except Exception as e: + return self.http_status(500, -1, f'Internal server error: {str(e)}') + + async def _handle_receive(self, connection, websocket_adapter): + """处理接收消息的任务""" + try: + while connection.is_active: + # 接收消息 + message = await quart.websocket.receive() + + # 更新活跃时间 + await ws_connection_manager.update_activity(connection.connection_id) + + try: + data = json.loads(message) + message_type = data.get('type', 'message') + + if message_type == 'ping': + # 心跳响应 + await connection.send_queue.put( + {'type': 'pong', 'timestamp': datetime.datetime.now().isoformat()} + ) + + elif message_type == 'message': + # 处理用户消息 + logger.debug(f'收到消息: {data} from {connection.connection_id}') + + # 处理消息(不等待响应,响应会通过broadcast异步发送) + await websocket_adapter.handle_websocket_message(connection, data) + + elif message_type == 'disconnect': + # 客户端主动断开 + logger.debug(f'Client disconnected: {connection.connection_id}') + break + + else: + logger.warning(f'Unknown message type: {message_type}') + + except json.JSONDecodeError: + logger.error(f'Invalid JSON message: {message}') + await connection.send_queue.put({'type': 'error', 'message': 'Invalid JSON format'}) + + except Exception as e: + logger.error(f'Receive message error: {e}', exc_info=True) + finally: + connection.is_active = False + + async def _handle_send(self, connection): + """处理发送消息的任务""" + try: + while connection.is_active: + # 从队列获取消息 + try: + message = await asyncio.wait_for(connection.send_queue.get(), timeout=1.0) + + # 发送消息 + await quart.websocket.send(json.dumps(message)) + + except asyncio.TimeoutError: + # 超时继续循环 + continue + + except Exception as e: + logger.error(f'Send message error: {e}', exc_info=True) + finally: + connection.is_active = False diff --git a/src/langbot/pkg/platform/botmgr.py b/src/langbot/pkg/platform/botmgr.py index 73c59da4..44305c92 100644 --- a/src/langbot/pkg/platform/botmgr.py +++ b/src/langbot/pkg/platform/botmgr.py @@ -156,7 +156,7 @@ class PlatformManager: bots: list[RuntimeBot] - webchat_proxy_bot: RuntimeBot + websocket_proxy_bot: RuntimeBot adapter_components: list[engine.Component] @@ -178,31 +178,29 @@ class PlatformManager: adapter_dict[component.metadata.name] = component.get_python_component_class() self.adapter_dict = adapter_dict - webchat_adapter_class = self.adapter_dict['webchat'] - - # initialize webchat adapter - webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap) - webchat_adapter_inst = webchat_adapter_class( + # initialize websocket adapter + websocket_adapter_class = self.adapter_dict['websocket'] + websocket_logger = EventLogger(name='websocket-adapter', ap=self.ap) + websocket_adapter_inst = websocket_adapter_class( {}, - webchat_logger, + websocket_logger, ap=self.ap, - is_stream=False, ) - self.webchat_proxy_bot = RuntimeBot( + self.websocket_proxy_bot = RuntimeBot( ap=self.ap, bot_entity=persistence_bot.Bot( - uuid='webchat-proxy-bot', - name='WebChat', + uuid='websocket-proxy-bot', + name='WebSocket', description='', - adapter='webchat', + adapter='websocket', adapter_config={}, enable=True, ), - adapter=webchat_adapter_inst, - logger=webchat_logger, + adapter=websocket_adapter_inst, + logger=websocket_logger, ) - await self.webchat_proxy_bot.initialize() + await self.websocket_proxy_bot.initialize() await self.load_bots_from_db() @@ -271,7 +269,7 @@ class PlatformManager: def get_available_adapters_info(self) -> list[dict]: return [ - component.to_plain_dict() for component in self.adapter_components if component.metadata.name != 'webchat' + component.to_plain_dict() for component in self.adapter_components if component.metadata.name != 'websocket' ] def get_available_adapter_info_by_name(self, name: str) -> dict | None: @@ -288,7 +286,7 @@ class PlatformManager: async def run(self): # This method will only be called when the application launching - await self.webchat_proxy_bot.run() + await self.websocket_proxy_bot.run() for bot in self.bots: if bot.enable: diff --git a/src/langbot/pkg/platform/sources/webchat.py b/src/langbot/pkg/platform/sources/webchat.py deleted file mode 100644 index 7fd54d1e..00000000 --- a/src/langbot/pkg/platform/sources/webchat.py +++ /dev/null @@ -1,304 +0,0 @@ -import asyncio -import logging -import typing -from datetime import datetime - -import pydantic - -import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter -import langbot_plugin.api.entities.builtin.platform.message as platform_message -import langbot_plugin.api.entities.builtin.platform.events as platform_events -import langbot_plugin.api.entities.builtin.platform.entities as platform_entities -import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger -from ...core import app - -logger = logging.getLogger(__name__) - - -class WebChatMessage(pydantic.BaseModel): - id: int - role: str - content: str - message_chain: list[dict] - timestamp: str - is_final: bool = False - - -class WebChatSession: - id: str - message_lists: dict[str, list[WebChatMessage]] = {} - resp_waiters: dict[int, asyncio.Future[WebChatMessage]] - resp_queues: dict[int, asyncio.Queue[WebChatMessage]] - - def __init__(self, id: str): - self.id = id - self.message_lists = {} - self.resp_waiters = {} - self.resp_queues = {} - - def get_message_list(self, pipeline_uuid: str) -> list[WebChatMessage]: - if pipeline_uuid not in self.message_lists: - self.message_lists[pipeline_uuid] = [] - - return self.message_lists[pipeline_uuid] - - -class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): - """WebChat调试适配器,用于流水线调试""" - - webchat_person_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession) - webchat_group_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession) - - listeners: dict[ - typing.Type[platform_events.Event], - typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None], - ] = pydantic.Field(default_factory=dict, exclude=True) - - is_stream: bool = pydantic.Field(exclude=True) - debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True) - - ap: app.Application = pydantic.Field(exclude=True) - - def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs): - super().__init__( - config=config, - logger=logger, - **kwargs, - ) - - self.webchat_person_session = WebChatSession(id='webchatperson') - self.webchat_group_session = WebChatSession(id='webchatgroup') - - self.bot_account_id = 'webchatbot' - - self.debug_messages = {} - - async def send_message( - self, - target_type: str, - target_id: str, - message: platform_message.MessageChain, - ) -> dict: - """发送消息到调试会话""" - session_key = target_id - - if session_key not in self.debug_messages: - self.debug_messages[session_key] = [] - - message_data = { - 'id': len(self.debug_messages[session_key]) + 1, - 'type': 'bot', - 'content': str(message), - 'timestamp': datetime.now().isoformat(), - 'message_chain': [component.__dict__ for component in message], - } - - self.debug_messages[session_key].append(message_data) - - await self.logger.info(f'Send message to {session_key}: {message}') - - return message_data - - async def reply_message( - self, - message_source: platform_events.MessageEvent, - message: platform_message.MessageChain, - quote_origin: bool = False, - ) -> dict: - """回复消息""" - message_data = WebChatMessage( - id=-1, - role='assistant', - content=str(message), - message_chain=[component.__dict__ for component in message], - timestamp=datetime.now().isoformat(), - ) - - # notify waiter - if isinstance(message_source, platform_events.FriendMessage): - await self.webchat_person_session.resp_queues[message_source.message_chain.message_id].put(message_data) - elif isinstance(message_source, platform_events.GroupMessage): - await self.webchat_group_session.resp_queues[message_source.message_chain.message_id].put(message_data) - - return message_data.model_dump() - - async def reply_message_chunk( - self, - message_source: platform_events.MessageEvent, - bot_message, - message: platform_message.MessageChain, - quote_origin: bool = False, - is_final: bool = False, - ) -> dict: - """回复消息""" - message_data = WebChatMessage( - id=-1, - role='assistant', - content=str(message), - message_chain=[component.__dict__ for component in message], - timestamp=datetime.now().isoformat(), - ) - - # notify waiter - session = ( - self.webchat_group_session - if isinstance(message_source, platform_events.GroupMessage) - else self.webchat_person_session - ) - if message_source.message_chain.message_id not in session.resp_waiters: - # session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue() - queue = session.resp_queues[message_source.message_chain.message_id] - - # if isinstance(message_source, platform_events.FriendMessage): - # queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id] - # elif isinstance(message_source, platform_events.GroupMessage): - # queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id] - if is_final and bot_message.tool_calls is None: - message_data.is_final = True - # print(message_data) - await queue.put(message_data) - - return message_data.model_dump() - - async def is_stream_output_supported(self) -> bool: - return self.is_stream - - def register_listener( - self, - event_type: typing.Type[platform_events.Event], - func: typing.Callable[ - [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None] - ], - ): - """注册事件监听器""" - self.listeners[event_type] = func - - def unregister_listener( - self, - event_type: typing.Type[platform_events.Event], - func: typing.Callable[ - [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None] - ], - ): - """取消注册事件监听器""" - del self.listeners[event_type] - - async def is_muted(self, group_id: int) -> bool: - return False - - async def run_async(self): - """运行适配器""" - await self.logger.info('WebChat调试适配器已启动') - - try: - while True: - await asyncio.sleep(1) - except asyncio.CancelledError: - await self.logger.info('WebChat调试适配器已停止') - raise - - async def kill(self): - """停止适配器""" - await self.logger.info('WebChat调试适配器正在停止') - - async def send_webchat_message( - self, - pipeline_uuid: str, - session_type: str, - message_chain_obj: typing.List[dict], - is_stream: bool = False, - ) -> dict: - self.is_stream = is_stream - """发送调试消息到流水线""" - if session_type == 'person': - use_session = self.webchat_person_session - else: - use_session = self.webchat_group_session - - message_chain = platform_message.MessageChain.parse_obj(message_chain_obj) - - message_id = len(use_session.get_message_list(pipeline_uuid)) + 1 - - use_session.resp_queues[message_id] = asyncio.Queue() - logger.debug(f'Initialized queue for message_id: {message_id}') - - use_session.get_message_list(pipeline_uuid).append( - WebChatMessage( - id=message_id, - role='user', - content=str(message_chain), - message_chain=message_chain_obj, - timestamp=datetime.now().isoformat(), - ) - ) - - message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp())) - - if session_type == 'person': - sender = platform_entities.Friend(id='webchatperson', nickname='User', remark='User') - event = platform_events.FriendMessage( - sender=sender, message_chain=message_chain, time=datetime.now().timestamp() - ) - else: - group = platform_entities.Group( - id='webchatgroup', name='Group', permission=platform_entities.Permission.Member - ) - sender = platform_entities.GroupMember( - id='webchatperson', - member_name='User', - group=group, - permission=platform_entities.Permission.Member, - ) - event = platform_events.GroupMessage( - sender=sender, message_chain=message_chain, time=datetime.now().timestamp() - ) - - self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid - - # trigger pipeline - if event.__class__ in self.listeners: - await self.listeners[event.__class__](event, self) - - if is_stream: - queue = use_session.resp_queues[message_id] - msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1 - while True: - resp_message = await queue.get() - resp_message.id = msg_id - if resp_message.is_final: - resp_message.id = msg_id - use_session.get_message_list(pipeline_uuid).append(resp_message) - yield resp_message.model_dump() - break - yield resp_message.model_dump() - use_session.resp_queues.pop(message_id) - - else: # non-stream - # set waiter - # waiter = asyncio.Future[WebChatMessage]() - # use_session.resp_waiters[message_id] = waiter - # # waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id)) - # - # resp_message = await waiter - # - # resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1 - # - # use_session.get_message_list(pipeline_uuid).append(resp_message) - # - # yield resp_message.model_dump() - msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1 - - queue = use_session.resp_queues[message_id] - resp_message = await queue.get() - use_session.get_message_list(pipeline_uuid).append(resp_message) - resp_message.id = msg_id - resp_message.is_final = True - - yield resp_message.model_dump() - - def get_webchat_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]: - """获取调试消息历史""" - if session_type == 'person': - return [message.model_dump() for message in self.webchat_person_session.get_message_list(pipeline_uuid)] - else: - return [message.model_dump() for message in self.webchat_group_session.get_message_list(pipeline_uuid)] diff --git a/src/langbot/pkg/platform/sources/webchat.yaml b/src/langbot/pkg/platform/sources/webchat.yaml deleted file mode 100644 index 748dfc8c..00000000 --- a/src/langbot/pkg/platform/sources/webchat.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: v1 -kind: MessagePlatformAdapter -metadata: - name: webchat - label: - en_US: "WebChat Debug" - zh_Hans: "网页聊天调试" - description: - en_US: "WebChat adapter for pipeline debugging" - zh_Hans: "用于流水线调试的网页聊天适配器" - icon: "" -spec: - config: [] -execution: - python: - path: "webchat.py" - attr: "WebChatAdapter" diff --git a/src/langbot/pkg/platform/sources/websocket.yaml b/src/langbot/pkg/platform/sources/websocket.yaml new file mode 100644 index 00000000..a358a030 --- /dev/null +++ b/src/langbot/pkg/platform/sources/websocket.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: MessagePlatformAdapter +metadata: + name: websocket + label: + en_US: "WebSocket Chat" + zh_Hans: "WebSocket 聊天" + description: + en_US: "WebSocket adapter for bidirectional real-time communication" + zh_Hans: "用于双向实时通信的 WebSocket 适配器" + icon: "" +spec: + config: [] +execution: + python: + path: "websocket_adapter.py" + attr: "WebSocketAdapter" diff --git a/src/langbot/pkg/platform/sources/websocket_adapter.py b/src/langbot/pkg/platform/sources/websocket_adapter.py new file mode 100644 index 00000000..6e03a699 --- /dev/null +++ b/src/langbot/pkg/platform/sources/websocket_adapter.py @@ -0,0 +1,402 @@ +"""WebSocket适配器 - 支持双向通信的IM系统""" + +import asyncio +import logging +import typing +from datetime import datetime + +import pydantic + +import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger +from ...core import app +from .websocket_manager import ws_connection_manager, WebSocketConnection + +logger = logging.getLogger(__name__) + + +class WebSocketMessage(pydantic.BaseModel): + """WebSocket消息格式""" + + id: int + role: str # 'user' or 'assistant' + content: str + message_chain: list[dict] + timestamp: str + is_final: bool = False + connection_id: str = '' + """发送者连接ID""" + + +class WebSocketSession: + """WebSocket会话 - 管理单个会话的消息历史""" + + id: str + message_lists: dict[str, list[WebSocketMessage]] = {} + """消息列表 {pipeline_uuid: [messages]}""" + + def __init__(self, id: str): + self.id = id + self.message_lists = {} + + def get_message_list(self, pipeline_uuid: str) -> list[WebSocketMessage]: + if pipeline_uuid not in self.message_lists: + self.message_lists[pipeline_uuid] = [] + return self.message_lists[pipeline_uuid] + + +class WebSocketAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): + """WebSocket适配器 - 支持双向实时通信""" + + websocket_person_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession) + websocket_group_session: WebSocketSession = pydantic.Field(exclude=True, default_factory=WebSocketSession) + + listeners: dict[ + typing.Type[platform_events.Event], + typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None], + ] = pydantic.Field(default_factory=dict, exclude=True) + + ap: app.Application = pydantic.Field(exclude=True) + + # 主动推送消息的队列 + outbound_message_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True) + """后端主动推送消息的队列""" + + def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs): + super().__init__( + config=config, + logger=logger, + **kwargs, + ) + + self.websocket_person_session = WebSocketSession(id='websocketperson') + self.websocket_group_session = WebSocketSession(id='websocketgroup') + + self.bot_account_id = 'websocketbot' + self.outbound_message_queue = asyncio.Queue() + + async def send_message( + self, + target_type: str, + target_id: str, + message: platform_message.MessageChain, + ) -> dict: + """发送消息 - 这里用于主动推送消息到前端""" + message_data = { + 'type': 'bot_message', + 'target_type': target_type, + 'target_id': target_id, + 'content': str(message), + 'message_chain': [component.__dict__ for component in message], + 'timestamp': datetime.now().isoformat(), + } + + # 推送到所有相关连接 + await self.outbound_message_queue.put(message_data) + + await self.logger.info(f'Send message to {target_id}: {message}') + + return message_data + + async def reply_message( + self, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, + quote_origin: bool = False, + ) -> dict: + """回复消息 - 非流式""" + # 获取会话和pipeline信息 + session = ( + self.websocket_group_session + if isinstance(message_source, platform_events.GroupMessage) + else self.websocket_person_session + ) + + # 从message_source获取pipeline_uuid和connection_id + pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid + # session_type = 'group' if isinstance(message_source, platform_events.GroupMessage) else 'person' + + # 生成新的消息ID + msg_id = len(session.get_message_list(pipeline_uuid)) + 1 + + message_data = WebSocketMessage( + id=msg_id, + role='assistant', + content=str(message), + message_chain=[component.__dict__ for component in message], + timestamp=datetime.now().isoformat(), + is_final=True, + ) + + # 保存到历史记录 + session.get_message_list(pipeline_uuid).append(message_data) + + # 直接广播到所有该pipeline的连接 + await ws_connection_manager.broadcast_to_pipeline( + pipeline_uuid, + { + 'type': 'response', + 'data': message_data.model_dump(), + }, + ) + + return message_data.model_dump() + + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + bot_message, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_final: bool = False, + ) -> dict: + """回复消息块 - 流式""" + # 获取会话和pipeline信息 + session = ( + self.websocket_group_session + if isinstance(message_source, platform_events.GroupMessage) + else self.websocket_person_session + ) + + pipeline_uuid = self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid + message_list = session.get_message_list(pipeline_uuid) + + # 检查是否是新的流式消息(通过bot_message对象判断) + # 如果列表为空,或者最后一条消息已经is_final=True,则创建新消息 + if not message_list or message_list[-1].is_final: + # 创建新消息 + msg_id = len(message_list) + 1 + message_data = WebSocketMessage( + id=msg_id, + role='assistant', + content=str(message), + message_chain=[component.__dict__ for component in message], + timestamp=datetime.now().isoformat(), + is_final=is_final and bot_message.tool_calls is None, + ) + + # 只有在is_final时才保存到历史记录 + if is_final and bot_message.tool_calls is None: + message_list.append(message_data) + else: + # 更新最后一条消息 + msg_id = message_list[-1].id + message_data = WebSocketMessage( + id=msg_id, + role='assistant', + content=str(message), + message_chain=[component.__dict__ for component in message], + timestamp=message_list[-1].timestamp, # 保持原始时间戳 + is_final=is_final and bot_message.tool_calls is None, + ) + + # 如果是final,更新历史记录中的最后一条 + if is_final and bot_message.tool_calls is None: + message_list[-1] = message_data + + # 直接广播到所有该pipeline的连接 + await ws_connection_manager.broadcast_to_pipeline( + pipeline_uuid, + { + 'type': 'response', + 'data': message_data.model_dump(), + }, + ) + + return message_data.model_dump() + + async def is_stream_output_supported(self) -> bool: + """WebSocket始终支持流式输出""" + return True + + def register_listener( + self, + event_type: typing.Type[platform_events.Event], + func: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None] + ], + ): + """注册事件监听器""" + self.listeners[event_type] = func + + def unregister_listener( + self, + event_type: typing.Type[platform_events.Event], + func: typing.Callable[ + [platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None] + ], + ): + """取消注册事件监听器""" + del self.listeners[event_type] + + async def is_muted(self, group_id: int) -> bool: + return False + + async def run_async(self): + """运行适配器""" + await self.logger.info('WebSocket适配器已启动') + + try: + while True: + # 处理主动推送消息 + if not self.outbound_message_queue.empty(): + try: + message = await asyncio.wait_for(self.outbound_message_queue.get(), timeout=0.1) + # 广播到所有相关连接 + target_id = message.get('target_id', '') + await ws_connection_manager.broadcast_to_pipeline(target_id, message) + except asyncio.TimeoutError: + pass + + await asyncio.sleep(0.1) + except asyncio.CancelledError: + await self.logger.info('WebSocket适配器已停止') + raise + + async def kill(self): + """停止适配器""" + await self.logger.info('WebSocket适配器正在停止') + + async def _process_image_components(self, message_chain_obj: list): + """ + 处理消息链中的图片组件,将path转换为base64 + + Args: + message_chain_obj: 消息链对象列表 + """ + import base64 + + storage_mgr = self.ap.storage_mgr + + for component in message_chain_obj: + if component.get('type') == 'Image' and component.get('path'): + try: + # 从storage读取文件 + file_content = await storage_mgr.storage_provider.load(component['path']) + + # 转换为base64 + base64_str = base64.b64encode(file_content).decode('utf-8') + + # 添加data URI前缀(根据文件扩展名判断MIME类型) + file_key = component['path'] + if file_key.lower().endswith(('.jpg', '.jpeg')): + mime_type = 'image/jpeg' + elif file_key.lower().endswith('.png'): + mime_type = 'image/png' + elif file_key.lower().endswith('.gif'): + mime_type = 'image/gif' + elif file_key.lower().endswith('.webp'): + mime_type = 'image/webp' + else: + mime_type = 'image/png' # 默认 + + component['base64'] = f'data:{mime_type};base64,{base64_str}' + await storage_mgr.storage_provider.delete(component['path']) + component['path'] = '' + # 保留path字段用于后端处理,前端使用base64显示 + except Exception as e: + await self.logger.error(f'加载图片文件失败 {component["path"]}: {e}') + + async def handle_websocket_message( + self, + connection: WebSocketConnection, + message_data: dict, + ): + """ + 处理从WebSocket接收的消息 + + 这个方法只负责接收消息、保存到历史记录、并触发事件处理 + 不等待任何响应,响应消息会通过reply_message/reply_message_chunk直接发送 + + Args: + connection: WebSocket连接对象 + message_data: 消息数据 + """ + pipeline_uuid = connection.pipeline_uuid + session_type = connection.session_type + + # 选择会话 + use_session = self.websocket_group_session if session_type == 'group' else self.websocket_person_session + + # 解析消息链 + message_chain_obj = message_data.get('message', []) + + # 处理图片组件:将path转换为base64 + await self._process_image_components(message_chain_obj) + + message_chain = platform_message.MessageChain.model_validate(message_chain_obj) + + # 生成消息ID + message_id = len(use_session.get_message_list(pipeline_uuid)) + 1 + + # 保存用户消息 + user_message = WebSocketMessage( + id=message_id, + role='user', + content=str(message_chain), + message_chain=message_chain_obj, + timestamp=datetime.now().isoformat(), + connection_id=connection.connection_id, + is_final=True, # 用户消息始终是完整的,非流式 + ) + use_session.get_message_list(pipeline_uuid).append(user_message) + + # 广播用户消息到所有连接(包括发送者) + await ws_connection_manager.broadcast_to_pipeline( + pipeline_uuid, + { + 'type': 'user_message', + 'data': user_message.model_dump(), + }, + ) + + # 添加消息源 + message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp())) + + # 创建事件 + if session_type == 'person': + sender = platform_entities.Friend( + id=f'websocket_{connection.connection_id}', nickname='User', remark='User' + ) + event = platform_events.FriendMessage( + sender=sender, message_chain=message_chain, time=datetime.now().timestamp() + ) + else: + group = platform_entities.Group( + id='websocketgroup', name='Group', permission=platform_entities.Permission.Member + ) + sender = platform_entities.GroupMember( + id=f'websocket_{connection.connection_id}', + member_name='User', + group=group, + permission=platform_entities.Permission.Member, + ) + event = platform_events.GroupMessage( + sender=sender, message_chain=message_chain, time=datetime.now().timestamp() + ) + + # 设置流水线UUID + self.ap.platform_mgr.websocket_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid + + # 异步触发事件处理(不等待结果) + if event.__class__ in self.listeners: + asyncio.create_task(self.listeners[event.__class__](event, self)) + + def get_websocket_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]: + """获取消息历史""" + if session_type == 'person': + return [message.model_dump() for message in self.websocket_person_session.get_message_list(pipeline_uuid)] + else: + return [message.model_dump() for message in self.websocket_group_session.get_message_list(pipeline_uuid)] + + def reset_session(self, pipeline_uuid: str, session_type: str): + """重置会话""" + if session_type == 'person': + if pipeline_uuid in self.websocket_person_session.message_lists: + self.websocket_person_session.message_lists[pipeline_uuid] = [] + else: + if pipeline_uuid in self.websocket_group_session.message_lists: + self.websocket_group_session.message_lists[pipeline_uuid] = [] diff --git a/src/langbot/pkg/platform/sources/websocket_manager.py b/src/langbot/pkg/platform/sources/websocket_manager.py new file mode 100644 index 00000000..767c5be8 --- /dev/null +++ b/src/langbot/pkg/platform/sources/websocket_manager.py @@ -0,0 +1,177 @@ +"""WebSocket连接管理器 - 管理多个并发WebSocket连接""" + +import asyncio +import logging +import typing +import uuid +from datetime import datetime + +import pydantic + +logger = logging.getLogger(__name__) + + +class WebSocketConnection(pydantic.BaseModel): + """单个WebSocket连接""" + + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + connection_id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4())) + """连接唯一ID""" + + pipeline_uuid: str + """关联的流水线UUID""" + + session_type: str # 'person' or 'group' + """会话类型""" + + websocket: typing.Any = pydantic.Field(exclude=True) + """WebSocket连接对象 (quart.websocket)""" + + created_at: datetime = pydantic.Field(default_factory=datetime.now) + """连接创建时间""" + + last_active: datetime = pydantic.Field(default_factory=datetime.now) + """最后活跃时间""" + + send_queue: asyncio.Queue = pydantic.Field(default_factory=asyncio.Queue, exclude=True) + """发送消息队列""" + + is_active: bool = True + """连接是否活跃""" + + metadata: dict = pydantic.Field(default_factory=dict) + """连接元数据(可存储额外信息)""" + + +class WebSocketConnectionManager: + """WebSocket连接管理器 - 支持多连接并发""" + + def __init__(self): + self.connections: dict[str, WebSocketConnection] = {} + """所有活跃连接 {connection_id: connection}""" + + self.pipeline_connections: dict[str, set[str]] = {} + """流水线到连接的映射 {pipeline_uuid: {connection_id, ...}}""" + + self.session_connections: dict[str, set[str]] = {} + """会话类型到连接的映射 {session_type: {connection_id, ...}}""" + + self._lock = asyncio.Lock() + """线程锁,保护并发访问""" + + async def add_connection( + self, + websocket: typing.Any, + pipeline_uuid: str, + session_type: str, + metadata: dict = None, + ) -> WebSocketConnection: + """添加新的WebSocket连接""" + async with self._lock: + connection = WebSocketConnection( + pipeline_uuid=pipeline_uuid, + session_type=session_type, + websocket=websocket, + metadata=metadata or {}, + ) + + self.connections[connection.connection_id] = connection + + # 更新流水线映射 + if pipeline_uuid not in self.pipeline_connections: + self.pipeline_connections[pipeline_uuid] = set() + self.pipeline_connections[pipeline_uuid].add(connection.connection_id) + + # 更新会话类型映射 + if session_type not in self.session_connections: + self.session_connections[session_type] = set() + self.session_connections[session_type].add(connection.connection_id) + + logger.debug( + f'WebSocket connection established: {connection.connection_id} ' + f'(pipeline={pipeline_uuid}, session_type={session_type})' + ) + + return connection + + async def remove_connection(self, connection_id: str): + """移除WebSocket连接""" + async with self._lock: + if connection_id not in self.connections: + return + + connection = self.connections[connection_id] + connection.is_active = False + + # 从流水线映射中移除 + if connection.pipeline_uuid in self.pipeline_connections: + self.pipeline_connections[connection.pipeline_uuid].discard(connection_id) + if not self.pipeline_connections[connection.pipeline_uuid]: + del self.pipeline_connections[connection.pipeline_uuid] + + # 从会话类型映射中移除 + if connection.session_type in self.session_connections: + self.session_connections[connection.session_type].discard(connection_id) + if not self.session_connections[connection.session_type]: + del self.session_connections[connection.session_type] + + del self.connections[connection_id] + + logger.debug(f'WebSocket connection disconnected: {connection_id}') + + async def get_connection(self, connection_id: str) -> typing.Optional[WebSocketConnection]: + """获取指定连接""" + return self.connections.get(connection_id) + + async def get_connections_by_pipeline(self, pipeline_uuid: str) -> list[WebSocketConnection]: + """获取指定流水线的所有连接""" + connection_ids = self.pipeline_connections.get(pipeline_uuid, set()) + return [self.connections[cid] for cid in connection_ids if cid in self.connections] + + async def get_connections_by_session_type(self, session_type: str) -> list[WebSocketConnection]: + """获取指定会话类型的所有连接""" + connection_ids = self.session_connections.get(session_type, set()) + return [self.connections[cid] for cid in connection_ids if cid in self.connections] + + async def broadcast_to_pipeline(self, pipeline_uuid: str, message: dict): + """向指定流水线的所有连接广播消息""" + connections = await self.get_connections_by_pipeline(pipeline_uuid) + tasks = [] + for conn in connections: + tasks.append(self.send_to_connection(conn.connection_id, message)) + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + async def send_to_connection(self, connection_id: str, message: dict): + """向指定连接发送消息""" + connection = await self.get_connection(connection_id) + if not connection or not connection.is_active: + logger.warning(f'Attempt to send message to invalid connection: {connection_id}') + return + + try: + await connection.send_queue.put(message) + connection.last_active = datetime.now() + except Exception as e: + logger.error(f'Failed to send message to connection {connection_id}: {e}') + await self.remove_connection(connection_id) + + async def update_activity(self, connection_id: str): + """更新连接活跃时间""" + connection = await self.get_connection(connection_id) + if connection: + connection.last_active = datetime.now() + + def get_stats(self) -> dict: + """获取连接统计信息""" + return { + 'total_connections': len(self.connections), + 'pipelines': len(self.pipeline_connections), + 'connections_by_pipeline': {k: len(v) for k, v in self.pipeline_connections.items()}, + 'connections_by_session_type': {k: len(v) for k, v in self.session_connections.items()}, + } + + +# 全局连接管理器实例 +ws_connection_manager = WebSocketConnectionManager() diff --git a/web/src/app/home/bots/components/bot-form/BotForm.tsx b/web/src/app/home/bots/components/bot-form/BotForm.tsx index bbc7fb69..4288599f 100644 --- a/web/src/app/home/bots/components/bot-form/BotForm.tsx +++ b/web/src/app/home/bots/components/bot-form/BotForm.tsx @@ -129,7 +129,6 @@ export default function BotForm({ form.setValue('adapter_config', val.adapter_config); form.setValue('enable', val.enable); form.setValue('use_pipeline_uuid', val.use_pipeline_uuid || ''); - console.log('form', form.getValues()); handleAdapterSelect(val.adapter); // dynamicForm.setFieldsValue(val.adapter_config); }) @@ -145,7 +144,6 @@ export default function BotForm({ async function initBotFormComponent() { // 初始化流水线列表 const pipelinesRes = await httpClient.getPipelines(); - console.log('rawPipelineList', pipelinesRes); setPipelineNameList( pipelinesRes.pipelines.map((item) => { return { @@ -157,7 +155,6 @@ export default function BotForm({ // 拉取adapter const adaptersRes = await httpClient.getAdapters(); - console.log('rawAdapterList', adaptersRes); setAdapterNameList( adaptersRes.adapters.map((item) => { return { @@ -253,12 +250,10 @@ export default function BotForm({ } // 只有通过外层固定表单验证才会走到这里,真正的提交逻辑在这里 - function onDynamicFormSubmit(value: object) { + function onDynamicFormSubmit() { setIsLoading(true); - console.log('set loading', true); if (initBotId) { // 编辑提交 - // console.log('submit edit', form.getFieldsValue(), value); const updateBot: Bot = { uuid: initBotId, name: form.getValues().name, @@ -270,8 +265,7 @@ export default function BotForm({ }; httpClient .updateBot(initBotId, updateBot) - .then((res) => { - console.log('update bot success', res); + .then(() => { onFormSubmit(form.getValues()); toast.success(t('bots.saveSuccess')); }) @@ -285,7 +279,6 @@ export default function BotForm({ }); } else { // 创建提交 - console.log('submit create', form.getValues(), value); const newBot: Bot = { name: form.getValues().name, description: form.getValues().description, @@ -295,7 +288,6 @@ export default function BotForm({ httpClient .createBot(newBot) .then((res) => { - console.log('create bot success', res); toast.success(t('bots.createSuccess')); initBotId = res.uuid; diff --git a/web/src/app/home/bots/page.tsx b/web/src/app/home/bots/page.tsx index 59dc83c3..b3c0dd25 100644 --- a/web/src/app/home/bots/page.tsx +++ b/web/src/app/home/bots/page.tsx @@ -86,7 +86,6 @@ export default function BotConfigPage() { } function handleNewBotCreated(botId: string) { - console.log('new bot created', botId); getBotList(); setSelectedBotId(botId); } diff --git a/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx index dd2178f2..f46d42b5 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx @@ -110,8 +110,6 @@ export default function DynamicFormComponent({ // 当 initialValues 变化时更新表单值 // 但要避免因为内部表单更新触发的 onSubmit 导致的 initialValues 变化而重新设置表单 useEffect(() => { - console.log('initialValues', initialValues); - // 首次挂载时,使用 initialValues 初始化表单 if (isInitialMount.current) { isInitialMount.current = false; @@ -148,7 +146,6 @@ export default function DynamicFormComponent({ const subscription = form.watch(() => { // 获取完整的表单值,确保包含所有默认值 const formValues = form.getValues(); - console.log('formValues', formValues); const finalValues = itemConfigList.reduce( (acc, item) => { acc[item.name] = formValues[item.name] ?? item.default; @@ -156,7 +153,6 @@ export default function DynamicFormComponent({ }, {} as Record, ); - console.log('finalValues', finalValues); onSubmit?.(finalValues); }); return () => subscription.unsubscribe(); diff --git a/web/src/app/home/components/home-sidebar/HomeSidebar.tsx b/web/src/app/home/components/home-sidebar/HomeSidebar.tsx index 100feb00..00575a77 100644 --- a/web/src/app/home/components/home-sidebar/HomeSidebar.tsx +++ b/web/src/app/home/components/home-sidebar/HomeSidebar.tsx @@ -66,7 +66,6 @@ export default function HomeSidebar({ .catch((error) => { console.error('Failed to fetch GitHub star count:', error); }); - return () => console.log('sidebar.unmounted'); }, []); function handleChildClick(child: SidebarChildVO) { @@ -90,7 +89,6 @@ export default function HomeSidebar({ } function handleRoute(child: SidebarChildVO) { - console.log(child); router.push(`${child.route}`); } @@ -102,7 +100,6 @@ export default function HomeSidebar({ routeList[1] === 'home' && sidebarConfigList.find((childConfig) => childConfig.route === pathname) ) { - console.log('find success'); const routeSelectChild = sidebarConfigList.find( (childConfig) => childConfig.route === pathname, ); @@ -144,7 +141,6 @@ export default function HomeSidebar({
{ - console.log('click:', config.id); handleChildClick(config); }} > diff --git a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx index 13276e8f..6ed5173c 100644 --- a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx +++ b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx @@ -103,8 +103,6 @@ export default function KBForm({ }; const onSubmit = (data: z.infer) => { - console.log('data', data); - if (initKbId) { // update knowledge base const updateKb: KnowledgeBase = { @@ -116,7 +114,6 @@ export default function KBForm({ httpClient .updateKnowledgeBase(initKbId, updateKb) .then((res) => { - console.log('update knowledge base success', res); onKbUpdated(res.uuid); toast.success(t('knowledge.updateKnowledgeBaseSuccess')); }) @@ -135,7 +132,6 @@ export default function KBForm({ httpClient .createKnowledgeBase(newKb) .then((res) => { - console.log('create knowledge base success', res); onNewKbCreated(res.uuid); }) .catch((err) => { @@ -200,7 +196,6 @@ export default function KBForm({ disabled={!!initKbId} onValueChange={(value) => { field.onChange(value); - console.log('value', value); }} value={field.value} > diff --git a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx index ec9cac6c..c52b7b65 100644 --- a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx +++ b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx @@ -326,8 +326,7 @@ export default function EmbeddingForm({ api_keys: apiKey ? [apiKey] : [], extra_args: extraArgsObj, }) - .then((res) => { - console.log(res); + .then(() => { toast.success(t('models.testSuccess')); }) .catch(() => { diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index a20d6745..b2c18d3e 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -341,8 +341,7 @@ export default function LLMForm({ abilities: form.getValues('abilities'), extra_args: extraArgsObj, }) - .then((res) => { - console.log(res); + .then(() => { toast.success(t('models.testSuccess')); }) .catch(() => { diff --git a/web/src/app/home/models/page.tsx b/web/src/app/home/models/page.tsx index 9a3a1597..7c33918e 100644 --- a/web/src/app/home/models/page.tsx +++ b/web/src/app/home/models/page.tsx @@ -54,7 +54,6 @@ export default function LLMConfigPage() { .getProviderLLMModels() .then((resp) => { const llmModelList: LLMCardVO[] = resp.models.map((model: LLMModel) => { - console.log('model', model); return new LLMCardVO({ id: model.uuid, iconURL: httpClient.getProviderRequesterIconURL(model.requester), @@ -66,7 +65,6 @@ export default function LLMConfigPage() { abilities: model.abilities || [], }); }); - console.log('get llmModelList', llmModelList); setCardList(llmModelList); }) .catch((err) => { @@ -78,7 +76,6 @@ export default function LLMConfigPage() { function selectLLM(cardVO: LLMCardVO) { setIsEditForm(true); setNowSelectedLLM(cardVO); - console.log('set now vo', cardVO); setModalOpen(true); } function handleCreateModelClick() { diff --git a/web/src/app/home/pipelines/PipelineDetailDialog.tsx b/web/src/app/home/pipelines/PipelineDetailDialog.tsx index 35a780fc..c51ad961 100644 --- a/web/src/app/home/pipelines/PipelineDetailDialog.tsx +++ b/web/src/app/home/pipelines/PipelineDetailDialog.tsx @@ -49,6 +49,7 @@ export default function PipelineDialog({ propPipelineId, ); const [currentMode, setCurrentMode] = useState('config'); + const [isWebSocketConnected, setIsWebSocketConnected] = useState(false); useEffect(() => { setPipelineId(propPipelineId); @@ -184,10 +185,29 @@ export default function PipelineDialog({
{getDialogTitle()} + {currentMode === 'debug' && ( +
+
+ + {isWebSocketConnected + ? t('pipelines.debugDialog.connected') + : t('pipelines.debugDialog.disconnected')} + +
+ )}
)}
diff --git a/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx b/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx index 71e5b748..f5432a16 100644 --- a/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx +++ b/web/src/app/home/pipelines/components/debug-dialog/DebugDialog.tsx @@ -4,30 +4,32 @@ import { httpClient } from '@/app/infra/http/HttpClient'; import { DialogContent } from '@/components/ui/dialog'; import { Button } from '@/components/ui/button'; import { Input } from '@/components/ui/input'; - import { ScrollArea } from '@/components/ui/scroll-area'; import { cn } from '@/lib/utils'; -import { Message } from '@/app/infra/entities/message'; +import { + Message, + MessageChainComponent, + Image, + Plain, + At, +} from '@/app/infra/entities/message'; import { toast } from 'sonner'; import AtBadge from './AtBadge'; -import { Switch } from '@/components/ui/switch'; - -interface MessageComponent { - type: 'At' | 'Plain'; - target?: string; - text?: string; -} +import { WebSocketClient } from '@/app/infra/websocket/WebSocketClient'; +import ImagePreviewDialog from './ImagePreviewDialog'; interface DebugDialogProps { open: boolean; pipelineId: string; isEmbedded?: boolean; + onConnectionStatusChange?: (isConnected: boolean) => void; } export default function DebugDialog({ open, pipelineId, isEmbedded = false, + onConnectionStatusChange, }: DebugDialogProps) { const { t } = useTranslation(); const [selectedPipelineId, setSelectedPipelineId] = useState(pipelineId); @@ -37,10 +39,19 @@ export default function DebugDialog({ const [showAtPopover, setShowAtPopover] = useState(false); const [hasAt, setHasAt] = useState(false); const [isHovering, setIsHovering] = useState(false); - const [isStreaming, setIsStreaming] = useState(true); + const [isConnected, setIsConnected] = useState(false); + const [selectedImages, setSelectedImages] = useState< + Array<{ file: File; preview: string; fileKey?: string }> + >([]); + const [isUploading, setIsUploading] = useState(false); + const [previewImageUrl, setPreviewImageUrl] = useState(''); + const [showImagePreview, setShowImagePreview] = useState(false); const messagesEndRef = useRef(null); const inputRef = useRef(null); const popoverRef = useRef(null); + const fileInputRef = useRef(null); + const wsClientRef = useRef(null); + const isInitializingRef = useRef(false); const scrollToBottom = useCallback(() => { // 使用setTimeout确保在DOM更新后执行滚动 @@ -60,7 +71,7 @@ export default function DebugDialog({ const loadMessages = useCallback( async (pipelineId: string) => { try { - const response = await httpClient.getWebChatHistoryMessages( + const response = await httpClient.getWebSocketHistoryMessages( pipelineId, sessionType, ); @@ -71,23 +82,123 @@ export default function DebugDialog({ }, [sessionType], ); + + // 初始化WebSocket连接 + const initWebSocket = useCallback( + async (pipelineId: string) => { + // 防止重复初始化 + if (isInitializingRef.current) { + return; + } + + try { + isInitializingRef.current = true; + + // 断开旧连接 + if (wsClientRef.current) { + wsClientRef.current.disconnect(); + wsClientRef.current = null; + } + + // 创建新连接 + const wsClient = new WebSocketClient(pipelineId, sessionType); + + wsClient + .onConnected(() => { + setIsConnected(true); + isInitializingRef.current = false; + }) + .onMessage((wsMessage) => { + // 将 WebSocketMessage 转换为 Message 类型 + const message: Message = { + ...wsMessage, + message_chain: wsMessage.message_chain as MessageChainComponent[], + }; + + setMessages((prevMessages) => { + // 查找是否已存在相同ID的消息 + const existingIndex = prevMessages.findIndex( + (m) => m.id === message.id, + ); + + if (existingIndex >= 0) { + // 更新已存在的消息(流式输出) + const newMessages = [...prevMessages]; + newMessages[existingIndex] = message; + return newMessages; + } else { + // 添加新消息 + return [...prevMessages, message]; + } + }); + }) + .onError((error) => { + console.error('WebSocket错误:', error); + setIsConnected(false); + isInitializingRef.current = false; + toast.error(t('pipelines.debugDialog.connectionError')); + }) + .onClose(() => { + setIsConnected(false); + isInitializingRef.current = false; + }) + .onBroadcast((message) => { + toast.info(message); + }); + + await wsClient.connect(); + wsClientRef.current = wsClient; + } catch (error) { + console.error('WebSocket连接失败:', error); + setIsConnected(false); + isInitializingRef.current = false; + toast.error(t('pipelines.debugDialog.connectionFailed')); + } + }, + [sessionType, t], + ); + // 在useEffect中监听messages变化时滚动 useEffect(() => { scrollToBottom(); }, [messages, scrollToBottom]); + // 监听 open 和 pipelineId 变化,进入时连接,离开时断开 useEffect(() => { if (open) { setSelectedPipelineId(pipelineId); - loadMessages(pipelineId); + } else { + // 关闭对话框时立即断开WebSocket + if (wsClientRef.current) { + wsClientRef.current.disconnect(); + wsClientRef.current = null; + setIsConnected(false); + isInitializingRef.current = false; + } } + + return () => { + // 组件卸载时断开WebSocket + if (wsClientRef.current) { + wsClientRef.current.disconnect(); + wsClientRef.current = null; + isInitializingRef.current = false; + } + }; }, [open, pipelineId]); + // 监听 sessionType 和 selectedPipelineId 变化,重新加载消息和连接 useEffect(() => { if (open) { loadMessages(selectedPipelineId); + initWebSocket(selectedPipelineId); } - }, [sessionType, selectedPipelineId, open, loadMessages]); + }, [sessionType, selectedPipelineId, open, loadMessages, initWebSocket]); + + // 通知父组件连接状态变化 + useEffect(() => { + onConnectionStatusChange?.(isConnected); + }, [isConnected, onConnectionStatusChange]); useEffect(() => { const handleClickOutside = (event: MouseEvent) => { @@ -147,10 +258,42 @@ export default function DebugDialog({ } }; + const handleImageSelect = async (e: React.ChangeEvent) => { + const files = e.target.files; + if (!files || files.length === 0) return; + + const newImages: Array<{ file: File; preview: string }> = []; + + for (let i = 0; i < files.length; i++) { + const file = files[i]; + if (file.type.startsWith('image/')) { + const preview = URL.createObjectURL(file); + newImages.push({ file, preview }); + } + } + + setSelectedImages((prev) => [...prev, ...newImages]); + }; + + const handleRemoveImage = (index: number) => { + setSelectedImages((prev) => { + const newImages = [...prev]; + URL.revokeObjectURL(newImages[index].preview); + newImages.splice(index, 1); + return newImages; + }); + }; + const sendMessage = async () => { - if (!inputValue.trim() && !hasAt) return; + if (!inputValue.trim() && !hasAt && selectedImages.length === 0) return; + if (!isConnected || !wsClientRef.current) { + toast.error(t('pipelines.debugDialog.notConnected')); + return; + } try { + setIsUploading(true); + const messageChain = []; let text_content = inputValue.trim(); @@ -161,142 +304,133 @@ export default function DebugDialog({ if (hasAt) { messageChain.push({ type: 'At', - target: 'webchatbot', + target: 'websocketbot', + display: 'websocketbot', }); } - messageChain.push({ - type: 'Plain', - text: text_content, - }); - if (hasAt) { - // for showing - text_content = '@webchatbot' + text_content; + // 添加文本 + if (text_content) { + messageChain.push({ + type: 'Plain', + text: text_content, + }); } - const userMessage: Message = { - id: -1, - role: 'user', - content: text_content, - timestamp: new Date().toISOString(), - message_chain: messageChain, - }; - // 根据isStreaming状态决定使用哪种传输方式 - if (isStreaming) { - // streaming - // 创建初始bot消息 - const placeholderRandomId = Math.floor(Math.random() * 1000000); - const botMessagePlaceholder: Message = { - id: placeholderRandomId, - role: 'assistant', - content: 'Generating...', - timestamp: new Date().toISOString(), - message_chain: [{ type: 'Plain', text: 'Generating...' }], - }; - - // 添加用户消息和初始bot消息到状态 - - setMessages((prevMessages) => [ - ...prevMessages, - userMessage, - botMessagePlaceholder, - ]); - setInputValue(''); - setHasAt(false); + // 上传图片并添加到消息链 + for (const image of selectedImages) { try { - await httpClient.sendStreamingWebChatMessage( - sessionType, - messageChain, + const result = await httpClient.uploadWebSocketImage( selectedPipelineId, - (data) => { - // 处理流式响应数据 - console.log('data', data); - if (data.message) { - // 更新完整内容 - - setMessages((prevMessages) => { - const updatedMessages = [...prevMessages]; - const botMessageIndex = updatedMessages.findIndex( - (message) => message.id === placeholderRandomId, - ); - if (botMessageIndex !== -1) { - updatedMessages[botMessageIndex] = { - ...updatedMessages[botMessageIndex], - content: data.message.content, - message_chain: [ - { type: 'Plain', text: data.message.content }, - ], - }; - } - return updatedMessages; - }); - } - }, - () => {}, - (error) => { - // 处理错误 - console.error('Streaming error:', error); - if (sessionType === 'person') { - toast.error(t('pipelines.debugDialog.sendFailed')); - } - }, + image.file, ); + messageChain.push({ + type: 'Image', + path: result.file_key, + }); } catch (error) { - console.error('Failed to send streaming message:', error); - if (sessionType === 'person') { - toast.error(t('pipelines.debugDialog.sendFailed')); - } + console.error('图片上传失败:', error); + toast.error(t('pipelines.debugDialog.imageUploadFailed')); } - } else { - // non-streaming - setMessages((prevMessages) => [...prevMessages, userMessage]); - setInputValue(''); - setHasAt(false); + } - const response = await httpClient.sendWebChatMessage( - sessionType, - messageChain, - selectedPipelineId, - 180000, + // 清空输入框和图片 + setInputValue(''); + setHasAt(false); + selectedImages.forEach((img) => URL.revokeObjectURL(img.preview)); + setSelectedImages([]); + + // 通过WebSocket发送消息 + // 不在本地添加消息,等待后端广播回来(带有正确的ID) + wsClientRef.current.sendMessage(messageChain); + } catch (error) { + console.error('Failed to send message:', error); + toast.error(t('pipelines.debugDialog.sendFailed')); + } finally { + setIsUploading(false); + inputRef.current?.focus(); + } + }; + + const renderMessageComponent = ( + component: MessageChainComponent, + index: number, + ) => { + switch (component.type) { + case 'Plain': + return {(component as Plain).text}; + + case 'At': { + const atComponent = component as At; + // 优先使用 display,如果没有则使用 target + const displayName = + atComponent.display || atComponent.target?.toString() || ''; + return ( + + + + ); + } + + case 'AtAll': + return ( + + + ); - setMessages((prevMessages) => [...prevMessages, response.message]); - } - } catch ( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - error: any - ) { - console.log(error, 'type of error', typeof error); - console.error('Failed to send message:', error); + case 'Image': { + const img = component as Image; + const imageUrl = img.url || (img.base64 ? img.base64 : ''); - if (!error.message.includes('timeout') && sessionType === 'person') { - toast.error(t('pipelines.debugDialog.sendFailed')); + if (!imageUrl) return null; + + return ( +
+ Image { + setPreviewImageUrl(imageUrl); + setShowImagePreview(true); + }} + /> +
+ ); } - } finally { - inputRef.current?.focus(); + + case 'File': { + const file = component as MessageChainComponent & { name?: string }; + return ( +
+ + + + [文件] {file.name || 'Unknown'} +
+ ); + } + + case 'Voice': + return [语音]; + + case 'Source': + // Source 不显示 + return null; + + default: + return [{component.type}]; } }; const renderMessageContent = (message: Message) => { return ( - - {(message.message_chain as MessageComponent[]).map( - (component, index) => { - if (component.type === 'At') { - return ( - - ); - } else if (component.type === 'Plain') { - return {component.text}; - } - return null; - }, +
+ {message.message_chain.map((component, index) => + renderMessageComponent(component, index), )} - +
); }; @@ -341,11 +475,10 @@ export default function DebugDialog({ -
- +
{messages.length === 0 ? (
@@ -389,16 +522,65 @@ export default function DebugDialog({
+ {/* 图片预览区域 */} + {selectedImages.length > 0 && ( +
+
+ {selectedImages.map((image, index) => ( +
+ {`preview-${index}`} + +
+ ))} +
+
+ )} +
-
- - {t('pipelines.debugDialog.streaming')} - - +
+ +
{hasAt && ( - + )}
{showAtPopover && (
setIsHovering(false)} > - @webchatbot - {t('pipelines.debugDialog.atTips')} + @websocketbot - {t('pipelines.debugDialog.atTips')}
@@ -440,10 +623,14 @@ export default function DebugDialog({
@@ -453,16 +640,30 @@ export default function DebugDialog({ // 如果是嵌入模式,直接返回内容 if (isEmbedded) { return ( -
-
{renderContent()}
-
+ <> +
+
{renderContent()}
+
+ setShowImagePreview(false)} + /> + ); } // 原有的Dialog包装 return ( - - {renderContent()} - + <> + + {renderContent()} + + setShowImagePreview(false)} + /> + ); } diff --git a/web/src/app/home/pipelines/components/debug-dialog/ImagePreviewDialog.tsx b/web/src/app/home/pipelines/components/debug-dialog/ImagePreviewDialog.tsx new file mode 100644 index 00000000..89eb9656 --- /dev/null +++ b/web/src/app/home/pipelines/components/debug-dialog/ImagePreviewDialog.tsx @@ -0,0 +1,56 @@ +import React from 'react'; + +interface ImagePreviewDialogProps { + open: boolean; + imageUrl: string; + onClose: () => void; +} + +export default function ImagePreviewDialog({ + open, + imageUrl, + onClose, +}: ImagePreviewDialogProps) { + if (!open) return null; + + return ( +
+ {/* 背景遮罩 */} +
+ + {/* 内容区域 */} +
+ {/* 关闭按钮 - 在图片上方 */} + + + {/* 图片 */} + Preview e.stopPropagation()} + /> +
+
+ ); +} diff --git a/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx b/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx index a262db32..7de82e04 100644 --- a/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx +++ b/web/src/app/home/pipelines/components/pipeline-form/PipelineFormComponent.tsx @@ -159,7 +159,6 @@ export default function PipelineFormComponent({ }, [form, isEditMode]); function handleFormSubmit(values: FormValues) { - console.log('handleFormSubmit', values); if (isEditMode) { handleModify(values); } else { @@ -168,7 +167,6 @@ export default function PipelineFormComponent({ } function handleCreate(values: FormValues) { - console.log('handleCreate', values); const pipeline: Pipeline = { config: {}, description: values.basic.description, diff --git a/web/src/app/home/pipelines/page.tsx b/web/src/app/home/pipelines/page.tsx index bb20abfa..81c25363 100644 --- a/web/src/app/home/pipelines/page.tsx +++ b/web/src/app/home/pipelines/page.tsx @@ -75,7 +75,6 @@ export default function PluginConfigPage() { setPipelineList(pipelineList); }) .catch((error) => { - console.log(error); toast.error(t('pipelines.getPipelineListError') + error.message); }); } diff --git a/web/src/app/home/plugins/components/plugin-market/PluginMarketComponent.tsx b/web/src/app/home/plugins/components/plugin-market/PluginMarketComponent.tsx index e74147c8..e4303882 100644 --- a/web/src/app/home/plugins/components/plugin-market/PluginMarketComponent.tsx +++ b/web/src/app/home/plugins/components/plugin-market/PluginMarketComponent.tsx @@ -329,7 +329,6 @@ function MarketPageContent({ // 安装插件 // const handleInstallPlugin = (plugin: PluginV4) => { - // console.log('install plugin', plugin); // }; return ( diff --git a/web/src/app/home/plugins/components/plugin-sort/PluginSortDialog.tsx b/web/src/app/home/plugins/components/plugin-sort/PluginSortDialog.tsx index 998ae93e..1c601e77 100644 --- a/web/src/app/home/plugins/components/plugin-sort/PluginSortDialog.tsx +++ b/web/src/app/home/plugins/components/plugin-sort/PluginSortDialog.tsx @@ -123,7 +123,6 @@ // function handleDragEnd(event: DragEndEvent) { // const { active, over } = event; -// console.log('Drag end event:', { active, over }); // if (over && active.id !== over.id) { // setSortedPlugins((items) => { diff --git a/web/src/app/home/plugins/page.tsx b/web/src/app/home/plugins/page.tsx index f9b73fac..e589f728 100644 --- a/web/src/app/home/plugins/page.tsx +++ b/web/src/app/home/plugins/page.tsx @@ -266,7 +266,6 @@ export default function PluginConfigPage() { watchTask(taskId); }) .catch((err) => { - console.log('error when install plugin:', err); setInstallError(err.message); setPluginInstallStatus(PluginInstallStatus.ERROR); }); @@ -278,7 +277,6 @@ export default function PluginConfigPage() { watchTask(taskId); }) .catch((err) => { - console.log('error when install plugin:', err); setInstallError(err.message); setPluginInstallStatus(PluginInstallStatus.ERROR); }); @@ -431,7 +429,9 @@ export default function PluginConfigPage() { return (
{ - return this.post( - `/api/v1/pipelines/${pipelineId}/chat/send`, - { - session_type: sessionType, - message: messageChain, - }, - { - timeout, - }, - ); - } - - public async sendStreamingWebChatMessage( - sessionType: string, - messageChain: object[], - pipelineId: string, - onMessage: (data: ApiRespWebChatMessage) => void, - onComplete: () => void, - onError: (error: Error) => void, - ): Promise { - try { - // 构造完整的URL,处理相对路径的情况 - let url = `${this.baseURL}/api/v1/pipelines/${pipelineId}/chat/send`; - if (this.baseURL === '/') { - // 获取用户访问的完整URL - const baseURL = window.location.origin; - url = `${baseURL}/api/v1/pipelines/${pipelineId}/chat/send`; - } - - // 使用fetch发送流式请求,因为axios在浏览器环境中不直接支持流式响应 - const response = await fetch(url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${this.getSessionSync()}`, - }, - body: JSON.stringify({ - session_type: sessionType, - message: messageChain, - is_stream: true, - }), - }); - - if (!response.ok) { - throw new Error(`HTTP error! status: ${response.status}`); - } - - if (!response.body) { - throw new Error('ReadableStream not supported'); - } - - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ''; - - // 读取流式响应 - try { - while (true) { - const { done, value } = await reader.read(); - - if (done) { - onComplete(); - break; - } - - // 解码数据 - buffer += decoder.decode(value, { stream: true }); - - // 处理完整的JSON对象 - const lines = buffer.split('\n\n'); - buffer = lines.pop() || ''; - - for (const line of lines) { - if (line.startsWith('data:')) { - try { - const data = JSON.parse(line.slice(5)); - - if (data.type === 'end') { - // 流传输结束 - reader.cancel(); - onComplete(); - return; - } - if (data.type === 'start') { - console.log(data.type); - } - - if (data.message) { - // 处理消息数据 - onMessage(data); - } - } catch (error) { - console.error('Error parsing streaming data:', error); - } - } - } - } - } finally { - reader.releaseLock(); - } - } catch (error) { - onError(error as Error); - } - } - - public getWebChatHistoryMessages( + // ============ WebSocket Chat API ============ + public getWebSocketHistoryMessages( pipelineId: string, sessionType: string, ): Promise { return this.get( - `/api/v1/pipelines/${pipelineId}/chat/messages/${sessionType}`, + `/api/v1/pipelines/${pipelineId}/ws/messages/${sessionType}`, ); } - public resetWebChatSession( + public async uploadWebSocketImage( + pipelineId: string, + imageFile: File, + ): Promise<{ file_key: string }> { + const formData = new FormData(); + formData.append('file', imageFile); + + return this.postFile(`/api/v1/files/images`, formData); + } + + public resetWebSocketSession( pipelineId: string, sessionType: string, ): Promise<{ message: string }> { - return this.post( - `/api/v1/pipelines/${pipelineId}/chat/reset/${sessionType}`, - ); + return this.post(`/api/v1/pipelines/${pipelineId}/ws/reset/${sessionType}`); + } + + public getWebSocketConnections(pipelineId: string): Promise<{ + stats: { + total_connections: number; + pipelines: number; + connections_by_pipeline: Record; + connections_by_session_type: Record; + }; + connections: Array<{ + connection_id: string; + session_type: string; + created_at: string; + last_active: string; + is_active: boolean; + }>; + }> { + return this.get(`/api/v1/pipelines/${pipelineId}/ws/connections`); + } + + public broadcastWebSocketMessage( + pipelineId: string, + message: string, + ): Promise<{ message: string }> { + return this.post(`/api/v1/pipelines/${pipelineId}/ws/broadcast`, { + message, + }); } // ============ Platform API ============ diff --git a/web/src/app/infra/http/BaseHttpClient.ts b/web/src/app/infra/http/BaseHttpClient.ts index cc2c31e4..11aaf21b 100644 --- a/web/src/app/infra/http/BaseHttpClient.ts +++ b/web/src/app/infra/http/BaseHttpClient.ts @@ -97,8 +97,6 @@ export abstract class BaseHttpClient { switch (status) { case 401: - console.log('401 error: ', errMessage, error.request); - console.log('responseURL', error.request.responseURL); if (typeof window !== 'undefined') { localStorage.removeItem('token'); if (!error.request.responseURL.includes('/check-token')) { diff --git a/web/src/app/infra/websocket/WebSocketClient.ts b/web/src/app/infra/websocket/WebSocketClient.ts new file mode 100644 index 00000000..426ee387 --- /dev/null +++ b/web/src/app/infra/websocket/WebSocketClient.ts @@ -0,0 +1,295 @@ +/** + * WebSocket客户端类 + * 用于管理WebSocket连接和消息处理 + */ +export interface WebSocketMessage { + id: number; + role: 'user' | 'assistant'; + content: string; + message_chain: Array<{ type: string; text?: string; target?: string }>; + timestamp: string; + is_final?: boolean; + connection_id?: string; +} + +export interface WebSocketResponse { + type: + | 'connected' + | 'response' + | 'user_message' + | 'pong' + | 'broadcast' + | 'error'; + connection_id?: string; + pipeline_uuid?: string; + session_type?: string; + timestamp?: string; + data?: WebSocketMessage; + message?: string; +} + +export class WebSocketClient { + private ws: WebSocket | null = null; + private connectionId: string | null = null; + private reconnectAttempts = 0; + private maxReconnectAttempts = 5; + private reconnectDelay = 3000; // 3秒重连间隔 + private heartbeatInterval: NodeJS.Timeout | null = null; + private heartbeatIntervalMs = 30000; // 30秒 + private isConnecting = false; // 防止重复连接 + + // 事件回调 + private onConnectedCallback?: (data: WebSocketResponse) => void; + private onMessageCallback?: (data: WebSocketMessage) => void; + private onErrorCallback?: (error: Error) => void; + private onCloseCallback?: () => void; + private onBroadcastCallback?: (message: string) => void; + + constructor( + private pipelineId: string, + private sessionType: 'person' | 'group' = 'person', + private token?: string, + ) {} + + /** + * 连接到WebSocket服务器 + */ + public connect(): Promise { + return new Promise((resolve, reject) => { + try { + // 防止重复连接 + if ( + this.isConnecting || + (this.ws && this.ws.readyState === WebSocket.CONNECTING) + ) { + console.warn('WebSocket正在连接中,忽略重复连接请求'); + reject(new Error('Connection already in progress')); + return; + } + + // 如果已经连接,直接返回 + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + console.warn('WebSocket已连接,忽略重复连接请求'); + resolve(this.connectionId || ''); + return; + } + + this.isConnecting = true; + + // 构建WebSocket URL + const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + // extract host from process.env.NEXT_PUBLIC_API_BASE_URL + const host = + process.env.NEXT_PUBLIC_API_BASE_URL?.split('://')[1] || ''; + const url = `${protocol}//${host}/api/v1/pipelines/${this.pipelineId}/ws/connect?session_type=${this.sessionType}`; + + this.ws = new WebSocket(url); + + // 连接打开 + this.ws.onopen = () => { + this.reconnectAttempts = 0; + this.isConnecting = false; + this.startHeartbeat(); + }; + + // 接收消息 + this.ws.onmessage = (event) => { + try { + const data: WebSocketResponse = JSON.parse(event.data); + this.handleMessage(data); + + // 第一次连接成功 + if (data.type === 'connected' && data.connection_id) { + this.connectionId = data.connection_id; + resolve(data.connection_id); + } + } catch (error) { + console.error('解析WebSocket消息失败:', error); + this.onErrorCallback?.(error as Error); + } + }; + + // 连接关闭 + this.ws.onclose = () => { + this.isConnecting = false; + this.stopHeartbeat(); + this.onCloseCallback?.(); + + // 自动重连 + if (this.reconnectAttempts < this.maxReconnectAttempts) { + this.reconnectAttempts++; + setTimeout(() => { + this.connect().catch(console.error); + }, this.reconnectDelay * this.reconnectAttempts); + } + }; + + // 连接错误 + this.ws.onerror = (event) => { + console.error('WebSocket错误:', event); + this.isConnecting = false; + const error = new Error('WebSocket连接失败'); + this.onErrorCallback?.(error); + reject(error); + }; + } catch (error) { + this.isConnecting = false; + reject(error); + } + }); + } + + /** + * 处理接收到的消息 + */ + private handleMessage(data: WebSocketResponse) { + switch (data.type) { + case 'connected': + this.onConnectedCallback?.(data); + break; + + case 'response': + if (data.data) { + this.onMessageCallback?.(data.data); + } + break; + + case 'user_message': + // 用户消息广播(包括自己发送的消息) + if (data.data) { + this.onMessageCallback?.(data.data); + } + break; + + case 'pong': + // 心跳响应 + break; + + case 'broadcast': + if (data.message) { + this.onBroadcastCallback?.(data.message); + } + break; + + case 'error': + const error = new Error(data.message || '未知错误'); + this.onErrorCallback?.(error); + break; + + default: + console.warn('未知消息类型:', data); + } + } + + /** + * 发送消息 + */ + public sendMessage( + messageChain: Array<{ type: string; text?: string; target?: string }>, + ) { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + throw new Error('WebSocket未连接'); + } + + const message = { + type: 'message', + message: messageChain, + }; + + this.ws.send(JSON.stringify(message)); + } + + /** + * 发送心跳 + */ + private sendHeartbeat() { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + return; + } + + this.ws.send(JSON.stringify({ type: 'ping' })); + } + + /** + * 启动心跳 + */ + private startHeartbeat() { + this.stopHeartbeat(); + this.heartbeatInterval = setInterval(() => { + this.sendHeartbeat(); + }, this.heartbeatIntervalMs); + } + + /** + * 停止心跳 + */ + private stopHeartbeat() { + if (this.heartbeatInterval) { + clearInterval(this.heartbeatInterval); + this.heartbeatInterval = null; + } + } + + /** + * 断开连接 + */ + public disconnect() { + if (this.ws) { + this.stopHeartbeat(); + + // 停止自动重连 + this.reconnectAttempts = this.maxReconnectAttempts; + + // 发送断开消息 + if (this.ws.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify({ type: 'disconnect' })); + } + + this.ws.close(); + this.ws = null; + this.connectionId = null; + this.isConnecting = false; + } + } + + /** + * 获取连接ID + */ + public getConnectionId(): string | null { + return this.connectionId; + } + + /** + * 获取连接状态 + */ + public isConnected(): boolean { + return this.ws !== null && this.ws.readyState === WebSocket.OPEN; + } + + // ===== 事件回调设置 ===== + + public onConnected(callback: (data: WebSocketResponse) => void) { + this.onConnectedCallback = callback; + return this; + } + + public onMessage(callback: (data: WebSocketMessage) => void) { + this.onMessageCallback = callback; + return this; + } + + public onError(callback: (error: Error) => void) { + this.onErrorCallback = callback; + return this; + } + + public onClose(callback: () => void) { + this.onCloseCallback = callback; + return this; + } + + public onBroadcast(callback: (message: string) => void) { + this.onBroadcastCallback = callback; + return this; + } +} diff --git a/web/src/app/login/page.tsx b/web/src/app/login/page.tsx index 42c1653e..b47e24f4 100644 --- a/web/src/app/login/page.tsx +++ b/web/src/app/login/page.tsx @@ -61,9 +61,7 @@ export default function Login() { router.push('/register'); } }) - .catch((err) => { - console.log('error at getIsInitialized: ', err); - }); + .catch(() => {}); } function checkIfAlreadyLoggedIn() { @@ -75,9 +73,7 @@ export default function Login() { router.push('/home'); } }) - .catch((err) => { - console.log('error at checkIfAlreadyLoggedIn: ', err); - }); + .catch(() => {}); } function onSubmit(values: z.infer>) { handleLogin(values.email, values.password); @@ -89,13 +85,10 @@ export default function Login() { .then((res) => { localStorage.setItem('token', res.token); localStorage.setItem('userEmail', username); - console.log('login success: ', res); router.push('/home'); toast.success(t('common.loginSuccess')); }) - .catch((err) => { - console.log('login error: ', err); - + .catch(() => { toast.error(t('common.loginFailed')); }); } diff --git a/web/src/app/register/page.tsx b/web/src/app/register/page.tsx index cd7edff8..e0ca6aee 100644 --- a/web/src/app/register/page.tsx +++ b/web/src/app/register/page.tsx @@ -59,9 +59,7 @@ export default function Register() { router.push('/login'); } }) - .catch((err) => { - console.log('error at getIsInitialized: ', err); - }); + .catch(() => {}); } function onSubmit(values: z.infer>) { @@ -71,13 +69,11 @@ export default function Register() { function handleRegister(username: string, password: string) { httpClient .initUser(username, password) - .then((res) => { - console.log('init user success: ', res); + .then(() => { toast.success(t('register.initSuccess')); router.push('/login'); }) - .catch((err) => { - console.log('init user error: ', err); + .catch((err: Error) => { toast.error(t('register.initFailed') + err.message); }); } diff --git a/web/src/app/reset-password/page.tsx b/web/src/app/reset-password/page.tsx index 11d7b748..f16a35fa 100644 --- a/web/src/app/reset-password/page.tsx +++ b/web/src/app/reset-password/page.tsx @@ -70,13 +70,11 @@ export default function ResetPassword() { setIsResetting(true); httpClient .resetPassword(email, recoveryKey, newPassword) - .then((res) => { - console.log('reset password success: ', res); + .then(() => { toast.success(t('resetPassword.resetSuccess')); router.push('/login'); }) - .catch((err) => { - console.log('reset password error: ', err); + .catch(() => { toast.error(t('resetPassword.resetFailed')); }) .finally(() => { diff --git a/web/src/i18n/I18nProvider.tsx b/web/src/i18n/I18nProvider.tsx index 91c78b09..8c6b4f28 100644 --- a/web/src/i18n/I18nProvider.tsx +++ b/web/src/i18n/I18nProvider.tsx @@ -23,8 +23,6 @@ export default function I18nProvider({ children }: I18nProviderProps) { export const extractI18nObject = (i18nObject: I18nObject): string => { // 根据当前语言返回对应的值, fallback优先级:en_US、zh_Hans、zh_Hant、ja_JP const language = i18n.language.replace('-', '_'); - console.log('language:', language); - console.log('i18nObject:', i18nObject); if (language === 'en_US' && i18nObject.en_US) return i18nObject.en_US; if (language === 'zh_Hans' && i18nObject.zh_Hans) return i18nObject.zh_Hans; if (language === 'zh_Hant' && i18nObject.zh_Hant) return i18nObject.zh_Hant; diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 352b50ca..6e6b11fb 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -518,6 +518,12 @@ const enUS = { loadPipelinesFailed: 'Failed to load pipelines', atTips: 'Mention the bot', streaming: 'Streaming', + connected: 'WebSocket connected', + disconnected: 'WebSocket disconnected', + connectionError: 'WebSocket connection error', + connectionFailed: 'WebSocket connection failed', + notConnected: 'WebSocket not connected, please try again later', + imageUploadFailed: 'Image upload failed', }, }, knowledge: { diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index dde0e128..d8c2c9b1 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -521,6 +521,13 @@ const jaJP = { loadPipelinesFailed: 'パイプラインの読み込みに失敗しました', atTips: 'ボットをメンション', streaming: 'ストリーミング', + connected: 'WebSocket接続済み', + disconnected: 'WebSocket未接続', + connectionError: 'WebSocket接続エラー', + connectionFailed: 'WebSocket接続に失敗しました', + notConnected: + 'WebSocketに接続されていません。しばらくしてからやり直してください', + imageUploadFailed: '画像のアップロードに失敗しました', }, }, knowledge: { diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 19f02ae0..0d096b5c 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -500,6 +500,12 @@ const zhHans = { loadPipelinesFailed: '加载流水线失败', atTips: '提及机器人', streaming: '流式传输', + connected: 'WebSocket已连接', + disconnected: 'WebSocket未连接', + connectionError: 'WebSocket连接错误', + connectionFailed: 'WebSocket连接失败', + notConnected: 'WebSocket未连接,请稍后重试', + imageUploadFailed: '图片上传失败', }, }, knowledge: { diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index a74588f4..d1b2bc85 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -497,6 +497,13 @@ const zhHant = { loadMessagesFailed: '載入訊息失敗', loadPipelinesFailed: '載入流程線失敗', atTips: '提及機器人', + streaming: '串流傳輸', + connected: 'WebSocket已連接', + disconnected: 'WebSocket未連接', + connectionError: 'WebSocket連接錯誤', + connectionFailed: 'WebSocket連接失敗', + notConnected: 'WebSocket未連接,請稍後重試', + imageUploadFailed: '圖片上傳失敗', }, }, knowledge: {