From 0eac9135c07bc7897cddf70859c8018f4d21b01d Mon Sep 17 00:00:00 2001 From: fdc Date: Mon, 30 Jun 2025 17:58:18 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fix.MD | 47 +++++++ pkg/core/entities.py | 2 +- pkg/pipeline/process/handlers/chat.py | 40 +++++- pkg/pipeline/respback/respback.py | 22 ++++ pkg/platform/adapter.py | 21 ++++ pkg/provider/entities.py | 83 +++++++++++++ pkg/provider/modelmgr/requester.py | 5 +- pkg/provider/modelmgr/requesters/chatcmpl.py | 90 ++++++++++++-- pkg/provider/runners/localagent.py | 123 +++++++++++++++---- 9 files changed, 387 insertions(+), 46 deletions(-) create mode 100644 fix.MD diff --git a/fix.MD b/fix.MD new file mode 100644 index 00000000..51927eb9 --- /dev/null +++ b/fix.MD @@ -0,0 +1,47 @@ +## 底层模型请求器 + +- pkg/provider/modelmgr/requesters/... + +给 invoke_llm 加个 stream: bool 参数,并允许 invoke_llm 返回两种参数:原来的 llm_entities.Message(非流式)和 返回 llm_entities.MessageChunk(流式,需要新增这个实体)的 AsyncGenerator + +## Runner + +- pkg/provider/runners/... + +每个runner的run方法也允许传入stream: bool。 + +现在的run方法本身就是生成器(AsyncGenerator),因为agent是有多回合的,会生成多条Message。但现在需要支持文本消息可以分段。 + +现在run方法应该返回 AsyncGenerator[ Union[ Message, AsyncGenerator[MessageChunk] ] ]。 + +对于 local agent 的实现上,调用模型invoke_llm时,传入stream,当发现模型返回的是Message时,即按照现在的写法操作Message;当返回的是 AsyncGenerator 时,需要 yield MessageChunk 给上层,同时需要注意判断工具调用。 + +## 流水线 + +- pkg/pipeline/process/handlers/chat.py + +之前这里就已经有一个生成器写法了,用于处理 AsyncGenerator[Message],但现在需要加上一个判断,如果yield出来的是 Message 则按照现在的处理;如果yield出来的是 AsyncGenerator,那么就需要再 async for 一层; + +因为流水线是基于责任链模式设计的,这里的生成结果只需要放入 Query 对象中,供下一层处理。 + +所以需要在 Query 对象中支持存入MessageChunk,现在只支持存 Message 到 resp_messages,这里得设计一下。 + +## 回复阶段 + +最终会在 pkg/pipeline/respback/respback.py 中检出 query 中的信息并发回,这里也要改成支持 MessagChunk 的。 + +这里应该判断适配器是否支持流式,若不支持,应该等待所有 MessageChunk 生成,拼接成 Message 再转换成 MessageChain 调用 send_message(); + +若支持,则uuid生成一个message id,使用该message id调用适配器的 reply_message_chunk 方法。 + +## 机器人适配器 + +因为机器人可能会由于用户配置项不同而表现为对流式的支持性不同,比如飞书默认不支持流式,需要用户额外配置卡片。 + +所以需要新增一个方法 `is_stream_output_supported() -> bool`,这个让每个适配器来判断并返回是否支持流式; + +在发送时,得加两个方法 `send_message_chunk(target_type: str, target_id: str, message_id: , message: MessageChain)` + +message_id 确定同一条消息,由调用方生成; + +`reply_message_chunk(message_source: MessageEvent, message: MessageChain)` \ No newline at end of file diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 4caf18ed..4873d9ce 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -87,7 +87,7 @@ class Query(pydantic.BaseModel): """使用的函数,由前置处理器阶段设置""" resp_messages: ( - typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] + typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] | typing.Optional[list[llm_entities.MessageChunk]] ) = [] """由Process阶段生成的回复消息对象列表""" diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 35fa1611..c90d283b 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -1,5 +1,6 @@ from __future__ import annotations +from itertools import accumulate import typing import traceback @@ -59,6 +60,8 @@ class ChatMessageHandler(handler.MessageHandler): text_length = 0 + is_stream = query.adapter.is_stream_output_supported() + try: for r in runner_module.preregistered_runners: if r.name == query.pipeline_config['ai']['runner']['runner']: @@ -66,18 +69,43 @@ class ChatMessageHandler(handler.MessageHandler): break else: raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}') + if is_stream: + accumulated_messages = [] + async for result in runner.run(query): + accumulated_messages.append(result) + query.resp_messages.append(result) - async for result in runner.run(query): - query.resp_messages.append(result) + self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') - self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') + if result.content is not None: + text_length += len(result.content) - if result.content is not None: - text_length += len(result.content) + # current_chain = platform_message.MessageChain([]) + # for msg in accumulated_messages: + # if msg.content is not None: + # current_chain.append(platform_message.Plain(msg.content)) + # query.resp_message_chain = [current_chain] + + + + + - yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + else: + + async for result in runner.run(query): + query.resp_messages.append(result) + + self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') + + if result.content is not None: + text_length += len(result.content) + + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) query.session.using_conversation.messages.append(query.user_message) + query.session.using_conversation.messages.extend(query.resp_messages) except Exception as e: self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}') diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 39d3abb1..7654896b 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -36,6 +36,28 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin = query.pipeline_config['output']['misc']['quote-origin'] + has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages) + if has_chunks and hasattr(query.adapter,'reply_message_chunk'): + + async def message_generator(): + for msg in query.resp_messages: + if isinstance(msg, llm_entities.MessageChunk): + yield msg.content + else: + yield msg.content + await query.adapter.reply_message_chunk( + message_source=query.message_event, + message_id=query.message_event.message_id, + message_generator=message_generator(), + quote_origin=quote_origin, + ) + else: + await query.adapter.reply_message( + message_source=query.message_event, + message=query.resp_message_chain[-1], + quote_origin=quote_origin, + ) + await query.adapter.reply_message( message_source=query.message_event, message=query.resp_message_chain[-1], diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index f28ad3dc..c841ae98 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -49,11 +49,27 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): async def reply_message( self, message_source: platform_events.MessageEvent, + message_id: int, message: platform_message.MessageChain, quote_origin: bool = False, ): """回复消息 + Args: + message_source (platform.types.MessageEvent): 消息源事件 + message_id (int): 消息ID + message (platform.types.MessageChain): 消息链 + quote_origin (bool, optional): 是否引用原消息. Defaults to False. + """ + raise NotImplementedError + + async def reply_message_chunk( + self, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, + quote_origin: bool = False, + ): + """回复消息(流式输出) Args: message_source (platform.types.MessageEvent): 消息源事件 message (platform.types.MessageChain): 消息链 @@ -94,6 +110,11 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): async def run_async(self): """异步运行""" raise NotImplementedError + + + async def is_stream_output_supported(self) -> bool: + """是否支持流式输出""" + return False async def kill(self) -> bool: """关闭适配器 diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 94b812d9..a149fea3 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -125,6 +125,89 @@ class Message(pydantic.BaseModel): return platform_message.MessageChain(mc) +class MessageChunk(pydantic.BaseModel): + """消息""" + + role: str # user, system, assistant, tool, command, plugin + """消息的角色""" + + name: typing.Optional[str] = None + """名称,仅函数调用返回时设置""" + + all_content: typing.Optional[str] = None + """所有内容""" + + content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None + """内容""" + + # tool_calls: typing.Optional[list[ToolCall]] = None + """工具调用""" + + tool_call_id: typing.Optional[str] = None + + tool_calls: typing.Optional[list[ToolCallChunk]] = None + + is_final: bool = False + + def readable_str(self) -> str: + if self.content is not None: + return str(self.role) + ': ' + str(self.get_content_platform_message_chain()) + elif self.tool_calls is not None: + return f'调用工具: {self.tool_calls[0].id}' + else: + return '未知消息' + + def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None: + """将内容转换为平台消息 MessageChain 对象 + + Args: + prefix_text (str): 首个文字组件的前缀文本 + """ + + if self.content is None: + return None + elif isinstance(self.content, str): + return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)]) + elif isinstance(self.content, list): + mc = [] + for ce in self.content: + if ce.type == 'text': + mc.append(platform_message.Plain(ce.text)) + elif ce.type == 'image_url': + if ce.image_url.url.startswith('http'): + mc.append(platform_message.Image(url=ce.image_url.url)) + else: # base64 + b64_str = ce.image_url.url + + if b64_str.startswith('data:'): + b64_str = b64_str.split(',')[1] + + mc.append(platform_message.Image(base64=b64_str)) + + # 找第一个文字组件 + if prefix_text: + for i, c in enumerate(mc): + if isinstance(c, platform_message.Plain): + mc[i] = platform_message.Plain(prefix_text + c.text) + break + else: + mc.insert(0, platform_message.Plain(prefix_text)) + + return platform_message.MessageChain(mc) + + +class ToolCallChunk(pydantic.BaseModel): + """工具调用""" + + id: str + """工具调用ID""" + + type: str + """工具调用类型""" + + function: FunctionCall + """函数调用""" + class Prompt(pydantic.BaseModel): """供AI使用的Prompt""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 244f4c82..3e5e791f 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -60,8 +60,9 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): model: RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, + stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.Message: + ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: """调用API Args: @@ -71,6 +72,6 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: - llm_entities.Message: 返回消息对象 + llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: 返回消息对象 """ pass diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 513086e5..22931611 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -57,13 +57,35 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): message = llm_entities.Message(**chatcmpl_message) return message + + async def _make_msg_chunk( + self, + chat_completion: chat_completion.ChatCompletion, + ) -> llm_entities.MessageChunk: + choice = chat_completion.choices[0] + delta = choice.delta.model_dump() + # 确保 role 字段存在且不为 None + if 'role' not in delta or delta['role'] is None: + delta['role'] = 'assistant' + + reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None + + # deepseek的reasoner模型 + if reasoning_content is not None: + delta['content'] = '\n' + reasoning_content + '\n\n' + delta['content'] + + message = llm_entities.MessageChunk(**delta) + + return message + async def _closure( self, query: core_entities.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, + stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() @@ -91,13 +113,42 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): args['messages'] = messages - # 发送请求 - resp = await self._req(args, extra_body=extra_args) + if stream: + current_content = '' + async for chunk in await self._req(args, extra_body=extra_args): - # 处理请求结果 - message = await self._make_msg(resp) + # 处理流式消息 + delta_message = await self._make_msg_chunk( + chat_completion=chunk, + ) + if delta_message.content: + current_content += delta_message.content + delta_message.all_content = current_content + + # 检查是否为最后一个块 + if chunk.choices[0].finish_reason is not None: + delta_message.is_final = True - return message + yield delta_message + return + + else: + + # 非流式请求 + resp = await self._req(args, extra_body=extra_args) + # 处理请求结果 + # 发送请求 + resp = await self._req(args, extra_body=extra_args) + + # 处理请求结果 + message = await self._make_msg(resp) + + return message + + + + + async def invoke_llm( self, @@ -105,8 +156,9 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, + stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.Message: + ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) @@ -119,13 +171,25 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): req_messages.append(msg_dict) try: - return await self._closure( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - extra_args=extra_args, - ) + if stream: + async for item in self._closure( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + stream=stream, + extra_args=extra_args, + ): + yield item + return + else: + return await self._closure( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + extra_args=extra_args, + ) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 7d5e04c5..02b2db16 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from ssl import ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE import typing from .. import runner @@ -12,26 +13,68 @@ from .. import entities as llm_entities class LocalAgentRunner(runner.RequestRunner): """本地Agent请求运行器""" - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + class ToolCallTracker: + """工具调用追踪器""" + def __init__(self): + self.active_calls: dict[str,dict] = {} + self.completed_calls: list[llm_entities.ToolCall] = [] + + async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]: """运行请求""" pending_tool_calls = [] req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] - # 首次请求 - msg = await query.use_llm_model.requester.invoke_llm( - query, - query.use_llm_model, - req_messages, - query.use_funcs, - extra_args=query.use_llm_model.model_entity.extra_args, - ) + is_stream = query.adapter.is_stream_output_supported() + # while True: + # pass + if not is_stream: + # 非流式输出,直接请求 + msg = await query.use_llm_model.requester.invoke_llm( + query, + query.use_llm_model, + req_messages, + query.use_funcs, + extra_args=query.use_llm_model.model_entity.extra_args, + ) + yield msg + final_msg = msg + else: + # 流式输出,需要处理工具调用 + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + async for msg in await query.use_llm_model.requester.invoke_llm( + query, + query.use_llm_model, + req_messages, + query.use_funcs, + stream=is_stream, + extra_args=query.use_llm_model.model_entity.extra_args, + ): + yield msg + if msg.tool_calls: + for tool_call in msg.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', + arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + final_msg = llm_entities.Message( + role=msg.role, + content=msg.all_content, + tool_calls=list(tool_calls_map.values()), + ) - yield msg + + pending_tool_calls = final_msg.tool_calls - pending_tool_calls = msg.tool_calls - - req_messages.append(msg) + req_messages.append(final_msg) # 持续请求,只要还有待处理的工具调用就继续处理调用 while pending_tool_calls: @@ -60,17 +103,49 @@ class LocalAgentRunner(runner.RequestRunner): req_messages.append(err_msg) - # 处理完所有调用,再次请求 - msg = await query.use_llm_model.requester.invoke_llm( - query, - query.use_llm_model, - req_messages, - query.use_funcs, - extra_args=query.use_llm_model.model_entity.extra_args, - ) + if is_stream: + tool_calls_map = {} + async for msg in await query.use_llm_model.requester.invoke_llm( + query, + query.use_llm_model, + req_messages, + query.use_funcs, + stream=is_stream, + extra_args=query.use_llm_model.model_entity.extra_args, + ): + yield msg + if msg.tool_calls: + for tool_call in msg.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', + arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + final_msg = llm_entities.Message( + role=msg.role, + content=all_content, + tool_calls=list(tool_calls_map.values()), + ) + else: + # 处理完所有调用,再次请求 + msg = await query.use_llm_model.requester.invoke_llm( + query, + query.use_llm_model, + req_messages, + query.use_funcs, + extra_args=query.use_llm_model.model_entity.extra_args, + ) - yield msg + yield msg + final_msg = msg - pending_tool_calls = msg.tool_calls + pending_tool_calls = final_msg.tool_calls - req_messages.append(msg) + req_messages.append(final_msg)