feat: 实现流式消息处理支持

This commit is contained in:
fdc
2025-06-30 17:58:18 +08:00
parent b43dd95dc6
commit 0eac9135c0
9 changed files with 387 additions and 46 deletions

47
fix.MD Normal file
View File

@@ -0,0 +1,47 @@
## 底层模型请求器
- pkg/provider/modelmgr/requesters/...
给 invoke_llm 加个 stream: bool 参数,并允许 invoke_llm 返回两种参数:原来的 llm_entities.Message非流式和 返回 llm_entities.MessageChunk流式需要新增这个实体的 AsyncGenerator
## Runner
- pkg/provider/runners/...
每个runner的run方法也允许传入stream: bool。
现在的run方法本身就是生成器AsyncGenerator因为agent是有多回合的会生成多条Message。但现在需要支持文本消息可以分段。
现在run方法应该返回 AsyncGenerator[ Union[ Message, AsyncGenerator[MessageChunk] ] ]。
对于 local agent 的实现上调用模型invoke_llm时传入stream当发现模型返回的是Message时即按照现在的写法操作Message当返回的是 AsyncGenerator 时,需要 yield MessageChunk 给上层,同时需要注意判断工具调用。
## 流水线
- pkg/pipeline/process/handlers/chat.py
之前这里就已经有一个生成器写法了,用于处理 AsyncGenerator[Message]但现在需要加上一个判断如果yield出来的是 Message 则按照现在的处理如果yield出来的是 AsyncGenerator那么就需要再 async for 一层;
因为流水线是基于责任链模式设计的,这里的生成结果只需要放入 Query 对象中,供下一层处理。
所以需要在 Query 对象中支持存入MessageChunk现在只支持存 Message 到 resp_messages这里得设计一下。
## 回复阶段
最终会在 pkg/pipeline/respback/respback.py 中检出 query 中的信息并发回,这里也要改成支持 MessagChunk 的。
这里应该判断适配器是否支持流式,若不支持,应该等待所有 MessageChunk 生成,拼接成 Message 再转换成 MessageChain 调用 send_message()
若支持则uuid生成一个message id使用该message id调用适配器的 reply_message_chunk 方法。
## 机器人适配器
因为机器人可能会由于用户配置项不同而表现为对流式的支持性不同,比如飞书默认不支持流式,需要用户额外配置卡片。
所以需要新增一个方法 `is_stream_output_supported() -> bool`,这个让每个适配器来判断并返回是否支持流式;
在发送时,得加两个方法 `send_message_chunk(target_type: str, target_id: str, message_id: , message: MessageChain)`
message_id 确定同一条消息,由调用方生成;
`reply_message_chunk(message_source: MessageEvent, message: MessageChain)`

View File

@@ -87,7 +87,7 @@ class Query(pydantic.BaseModel):
"""使用的函数,由前置处理器阶段设置"""
resp_messages: (
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]]
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] | typing.Optional[list[llm_entities.MessageChunk]]
) = []
"""由Process阶段生成的回复消息对象列表"""

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from itertools import accumulate
import typing
import traceback
@@ -59,6 +60,8 @@ class ChatMessageHandler(handler.MessageHandler):
text_length = 0
is_stream = query.adapter.is_stream_output_supported()
try:
for r in runner_module.preregistered_runners:
if r.name == query.pipeline_config['ai']['runner']['runner']:
@@ -66,18 +69,43 @@ class ChatMessageHandler(handler.MessageHandler):
break
else:
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
if is_stream:
accumulated_messages = []
async for result in runner.run(query):
accumulated_messages.append(result)
query.resp_messages.append(result)
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}')
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
if result.content is not None:
text_length += len(result.content)
# current_chain = platform_message.MessageChain([])
# for msg in accumulated_messages:
# if msg.content is not None:
# current_chain.append(platform_message.Plain(msg.content))
# query.resp_message_chain = [current_chain]
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')

View File

@@ -36,6 +36,28 @@ class SendResponseBackStage(stage.PipelineStage):
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages)
if has_chunks and hasattr(query.adapter,'reply_message_chunk'):
async def message_generator():
for msg in query.resp_messages:
if isinstance(msg, llm_entities.MessageChunk):
yield msg.content
else:
yield msg.content
await query.adapter.reply_message_chunk(
message_source=query.message_event,
message_id=query.message_event.message_id,
message_generator=message_generator(),
quote_origin=quote_origin,
)
else:
await query.adapter.reply_message(
message_source=query.message_event,
message=query.resp_message_chain[-1],
quote_origin=quote_origin,
)
await query.adapter.reply_message(
message_source=query.message_event,
message=query.resp_message_chain[-1],

View File

@@ -49,11 +49,27 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message_id: int,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
"""回复消息
Args:
message_source (platform.types.MessageEvent): 消息源事件
message_id (int): 消息ID
message (platform.types.MessageChain): 消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
raise NotImplementedError
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
"""回复消息(流式输出)
Args:
message_source (platform.types.MessageEvent): 消息源事件
message (platform.types.MessageChain): 消息链
@@ -94,6 +110,11 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
async def run_async(self):
"""异步运行"""
raise NotImplementedError
async def is_stream_output_supported(self) -> bool:
"""是否支持流式输出"""
return False
async def kill(self) -> bool:
"""关闭适配器

View File

@@ -125,6 +125,89 @@ class Message(pydantic.BaseModel):
return platform_message.MessageChain(mc)
class MessageChunk(pydantic.BaseModel):
"""消息"""
role: str # user, system, assistant, tool, command, plugin
"""消息的角色"""
name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置"""
all_content: typing.Optional[str] = None
"""所有内容"""
content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
"""内容"""
# tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用"""
tool_call_id: typing.Optional[str] = None
tool_calls: typing.Optional[list[ToolCallChunk]] = None
is_final: bool = False
def readable_str(self) -> str:
if self.content is not None:
return str(self.role) + ': ' + str(self.get_content_platform_message_chain())
elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}'
else:
return '未知消息'
def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None:
"""将内容转换为平台消息 MessageChain 对象
Args:
prefix_text (str): 首个文字组件的前缀文本
"""
if self.content is None:
return None
elif isinstance(self.content, str):
return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)])
elif isinstance(self.content, list):
mc = []
for ce in self.content:
if ce.type == 'text':
mc.append(platform_message.Plain(ce.text))
elif ce.type == 'image_url':
if ce.image_url.url.startswith('http'):
mc.append(platform_message.Image(url=ce.image_url.url))
else: # base64
b64_str = ce.image_url.url
if b64_str.startswith('data:'):
b64_str = b64_str.split(',')[1]
mc.append(platform_message.Image(base64=b64_str))
# 找第一个文字组件
if prefix_text:
for i, c in enumerate(mc):
if isinstance(c, platform_message.Plain):
mc[i] = platform_message.Plain(prefix_text + c.text)
break
else:
mc.insert(0, platform_message.Plain(prefix_text))
return platform_message.MessageChain(mc)
class ToolCallChunk(pydantic.BaseModel):
"""工具调用"""
id: str
"""工具调用ID"""
type: str
"""工具调用类型"""
function: FunctionCall
"""函数调用"""
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""

View File

@@ -60,8 +60,9 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
model: RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
"""调用API
Args:
@@ -71,6 +72,6 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
Returns:
llm_entities.Message: 返回消息对象
llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: 返回消息对象
"""
pass

View File

@@ -57,13 +57,35 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
message = llm_entities.Message(**chatcmpl_message)
return message
async def _make_msg_chunk(
self,
chat_completion: chat_completion.ChatCompletion,
) -> llm_entities.MessageChunk:
choice = chat_completion.choices[0]
delta = choice.delta.model_dump()
# 确保 role 字段存在且不为 None
if 'role' not in delta or delta['role'] is None:
delta['role'] = 'assistant'
reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None
# deepseek的reasoner模型
if reasoning_content is not None:
delta['content'] = '<think>\n' + reasoning_content + '\n</think>\n' + delta['content']
message = llm_entities.MessageChunk(**delta)
return message
async def _closure(
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
@@ -91,13 +113,42 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
args['messages'] = messages
# 发送请求
resp = await self._req(args, extra_body=extra_args)
if stream:
current_content = ''
async for chunk in await self._req(args, extra_body=extra_args):
# 处理请求结果
message = await self._make_msg(resp)
# 处理流式消息
delta_message = await self._make_msg_chunk(
chat_completion=chunk,
)
if delta_message.content:
current_content += delta_message.content
delta_message.all_content = current_content
# 检查是否为最后一个块
if chunk.choices[0].finish_reason is not None:
delta_message.is_final = True
return message
yield delta_message
return
else:
# 非流式请求
resp = await self._req(args, extra_body=extra_args)
# 处理请求结果
# 发送请求
resp = await self._req(args, extra_body=extra_args)
# 处理请求结果
message = await self._make_msg(resp)
return message
async def invoke_llm(
self,
@@ -105,8 +156,9 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
for m in messages:
msg_dict = m.dict(exclude_none=True)
@@ -119,13 +171,25 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
req_messages.append(msg_dict)
try:
return await self._closure(
query=query,
req_messages=req_messages,
use_model=model,
use_funcs=funcs,
extra_args=extra_args,
)
if stream:
async for item in self._closure(
query=query,
req_messages=req_messages,
use_model=model,
use_funcs=funcs,
stream=stream,
extra_args=extra_args,
):
yield item
return
else:
return await self._closure(
query=query,
req_messages=req_messages,
use_model=model,
use_funcs=funcs,
extra_args=extra_args,
)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import json
from ssl import ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
import typing
from .. import runner
@@ -12,26 +13,68 @@ from .. import entities as llm_entities
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器"""
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
class ToolCallTracker:
"""工具调用追踪器"""
def __init__(self):
self.active_calls: dict[str,dict] = {}
self.completed_calls: list[llm_entities.ToolCall] = []
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]:
"""运行请求"""
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_llm_model.requester.invoke_llm(
query,
query.use_llm_model,
req_messages,
query.use_funcs,
extra_args=query.use_llm_model.model_entity.extra_args,
)
is_stream = query.adapter.is_stream_output_supported()
# while True:
# pass
if not is_stream:
# 非流式输出,直接请求
msg = await query.use_llm_model.requester.invoke_llm(
query,
query.use_llm_model,
req_messages,
query.use_funcs,
extra_args=query.use_llm_model.model_entity.extra_args,
)
yield msg
final_msg = msg
else:
# 流式输出,需要处理工具调用
tool_calls_map: dict[str, llm_entities.ToolCall] = {}
async for msg in await query.use_llm_model.requester.invoke_llm(
query,
query.use_llm_model,
req_messages,
query.use_funcs,
stream=is_stream,
extra_args=query.use_llm_model.model_entity.extra_args,
):
yield msg
if msg.tool_calls:
for tool_call in msg.tool_calls:
if tool_call.id not in tool_calls_map:
tool_calls_map[tool_call.id] = llm_entities.ToolCall(
id=tool_call.id,
type=tool_call.type,
function=llm_entities.FunctionCall(
name=tool_call.function.name if tool_call.function else '',
arguments=''
),
)
if tool_call.function and tool_call.function.arguments:
# 流式处理中工具调用参数可能分多个chunk返回需要追加而不是覆盖
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
final_msg = llm_entities.Message(
role=msg.role,
content=msg.all_content,
tool_calls=list(tool_calls_map.values()),
)
yield msg
pending_tool_calls = final_msg.tool_calls
pending_tool_calls = msg.tool_calls
req_messages.append(msg)
req_messages.append(final_msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
@@ -60,17 +103,49 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_llm_model.requester.invoke_llm(
query,
query.use_llm_model,
req_messages,
query.use_funcs,
extra_args=query.use_llm_model.model_entity.extra_args,
)
if is_stream:
tool_calls_map = {}
async for msg in await query.use_llm_model.requester.invoke_llm(
query,
query.use_llm_model,
req_messages,
query.use_funcs,
stream=is_stream,
extra_args=query.use_llm_model.model_entity.extra_args,
):
yield msg
if msg.tool_calls:
for tool_call in msg.tool_calls:
if tool_call.id not in tool_calls_map:
tool_calls_map[tool_call.id] = llm_entities.ToolCall(
id=tool_call.id,
type=tool_call.type,
function=llm_entities.FunctionCall(
name=tool_call.function.name if tool_call.function else '',
arguments=''
),
)
if tool_call.function and tool_call.function.arguments:
# 流式处理中工具调用参数可能分多个chunk返回需要追加而不是覆盖
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
final_msg = llm_entities.Message(
role=msg.role,
content=all_content,
tool_calls=list(tool_calls_map.values()),
)
else:
# 处理完所有调用,再次请求
msg = await query.use_llm_model.requester.invoke_llm(
query,
query.use_llm_model,
req_messages,
query.use_funcs,
extra_args=query.use_llm_model.model_entity.extra_args,
)
yield msg
yield msg
final_msg = msg
pending_tool_calls = msg.tool_calls
pending_tool_calls = final_msg.tool_calls
req_messages.append(msg)
req_messages.append(final_msg)