diff --git a/pkg/api/http/controller/groups/pipelines/webchat.py b/pkg/api/http/controller/groups/pipelines/webchat.py index 005738db..a3bf8585 100644 --- a/pkg/api/http/controller/groups/pipelines/webchat.py +++ b/pkg/api/http/controller/groups/pipelines/webchat.py @@ -1,3 +1,5 @@ +import json + import quart from ... import group @@ -9,10 +11,16 @@ class WebChatDebugRouterGroup(group.RouterGroup): @self.route('/send', methods=['POST']) async def send_message(pipeline_uuid: str) -> str: """发送调试消息到流水线""" + + async def stream_generator(generator): + async for message in generator: + yield rf"data:{json.dumps({'message': message})}\n\n" + yield "data:{'type': 'end'}\n\n''" try: data = await quart.request.get_json() session_type = data.get('session_type', 'person') message_chain_obj = data.get('message', []) + is_stream = data.get('is_stream', False) if not message_chain_obj: return self.http_status(400, -1, 'message is required') @@ -25,13 +33,29 @@ class WebChatDebugRouterGroup(group.RouterGroup): if not webchat_adapter: return self.http_status(404, -1, 'WebChat adapter not found') - result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj) + if is_stream: + + generator = webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj, is_stream) + + return quart.Response( + stream_generator(generator), + mimetype='text/event-stream' + ) + + else: + # result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj) + result = None + async for message in webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj): + result = message + if result is not None: + return self.success( + data={ + 'message': result, + } + ) + else: + return self.http_status(400, -1, 'message is required') - return self.success( - data={ - 'message': result, - } - ) except Exception as e: return self.http_status(500, -1, f'Internal server error: {str(e)}') diff --git a/pkg/platform/sources/webchat.py b/pkg/platform/sources/webchat.py index 51b0479f..7fd7bb3b 100644 --- a/pkg/platform/sources/webchat.py +++ b/pkg/platform/sources/webchat.py @@ -25,11 +25,13 @@ class WebChatSession: id: str message_lists: dict[str, list[WebChatMessage]] = {} resp_waiters: dict[int, asyncio.Future[WebChatMessage]] + resp_queues = dict[int, asyncio.Queue[WebChatMessage]] def __init__(self, id: str): self.id = id self.message_lists = {} self.resp_waiters = {} + self.resp_queues = {} def get_message_list(self, pipeline_uuid: str) -> list[WebChatMessage]: if pipeline_uuid not in self.message_lists: @@ -108,6 +110,35 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): return message_data.model_dump() + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + message_id: str, + message: platform_message.MessageChain, + quote_origin: bool = False, + is_fianl: bool = False, + ) -> dict: + """回复消息""" + message_data = WebChatMessage( + id=-1, + role='assistant', + content=str(message), + message_chain=[component.__dict__ for component in message], + timestamp=datetime.now().isoformat(), + ) + + # notify waiter + if isinstance(message_source, platform_events.FriendMessage): + queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id] + elif isinstance(message_source, platform_events.GroupMessage): + queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id] + + queue.put(message_data) + if is_fianl: + queue.put(None) + + return message_data.model_dump() + def register_listener( self, event_type: typing.Type[platform_events.Event], @@ -140,7 +171,8 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): await self.logger.info('WebChat调试适配器正在停止') async def send_webchat_message( - self, pipeline_uuid: str, session_type: str, message_chain_obj: typing.List[dict] + self, pipeline_uuid: str, session_type: str, message_chain_obj: typing.List[dict], + is_stream: bool = False, ) -> dict: """发送调试消息到流水线""" if session_type == 'person': @@ -188,18 +220,29 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) - # set waiter - waiter = asyncio.Future[WebChatMessage]() - use_session.resp_waiters[message_id] = waiter - waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id)) + if is_stream: + queue = use_session.resp_queues[message_id] + while True: + resp_message = await queue.get() + if resp_message is None: + resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1 + use_session.get_message_list(pipeline_uuid).append(resp_message) + break + yield resp_message.model_dump() - resp_message = await waiter + else: + # set waiter + waiter = asyncio.Future[WebChatMessage]() + use_session.resp_waiters[message_id] = waiter + waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id)) - resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1 + resp_message = await waiter - use_session.get_message_list(pipeline_uuid).append(resp_message) + resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1 - return resp_message.model_dump() + use_session.get_message_list(pipeline_uuid).append(resp_message) + + yield resp_message.model_dump() def get_webchat_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]: """获取调试消息历史""" diff --git a/pkg/platform/sources/webchat.yaml b/pkg/platform/sources/webchat.yaml index 4e8cc38e..0b1d4c29 100644 --- a/pkg/platform/sources/webchat.yaml +++ b/pkg/platform/sources/webchat.yaml @@ -9,7 +9,18 @@ metadata: en_US: "WebChat adapter for pipeline debugging" zh_Hans: "用于流水线调试的网页聊天适配器" icon: "" -spec: {} +spec: + config: + - name: enable-stream-reply + label: + en_US: Enable Stream Reply Mode + zh_Hans: 启用电报流式回复模式 + description: + en_US: If enabled, the bot will use the stream of telegram reply mode + zh_Hans: 如果启用,将使用电报流式方式来回复内容 + type: boolean + required: true + default: false execution: python: path: "webchat.py" diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index a86cdbe8..34c9f61f 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -34,6 +34,7 @@ import { } from '@/app/infra/entities/api'; import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest'; import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse'; +import {boolean} from "zod"; type JSONValue = string | number | boolean | JSONObject | JSONArray | null; interface JSONObject { @@ -309,12 +310,14 @@ class HttpClient { messageChain: object[], pipelineId: string, timeout: number = 15000, + is_stream: boolean = false, ): Promise { return this.post( `/api/v1/pipelines/${pipelineId}/chat/send`, { session_type: sessionType, message: messageChain, + is_stream: is_stream, }, { timeout,