from __future__ import annotations import json import typing from .. import runner from ...core import entities as core_entities from .. import entities as llm_entities @runner.runner_class('local-agent') class LocalAgentRunner(runner.RequestRunner): """本地Agent请求运行器""" class ToolCallTracker: """工具调用追踪器""" def __init__(self): self.active_calls: dict[str, dict] = {} self.completed_calls: list[llm_entities.ToolCall] = [] async def run( self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]: """运行请求""" pending_tool_calls = [] req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] try: is_stream = await query.adapter.is_stream_output_supported() except AttributeError: is_stream = False if not is_stream: # 非流式输出,直接请求 msg = await query.use_llm_model.requester.invoke_llm( query, query.use_llm_model, req_messages, query.use_funcs, extra_args=query.use_llm_model.model_entity.extra_args, ) yield msg final_msg = msg else: # 流式输出,需要处理工具调用 tool_calls_map: dict[str, llm_entities.ToolCall] = {} async for msg in query.use_llm_model.requester.invoke_llm_stream( query, query.use_llm_model, req_messages, query.use_funcs, extra_args=query.use_llm_model.model_entity.extra_args, ): yield msg if msg.tool_calls: for tool_call in msg.tool_calls: if tool_call.id not in tool_calls_map: tool_calls_map[tool_call.id] = llm_entities.ToolCall( id=tool_call.id, type=tool_call.type, function=llm_entities.FunctionCall( name=tool_call.function.name if tool_call.function else '', arguments='' ), ) if tool_call.function and tool_call.function.arguments: # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments final_msg = llm_entities.Message( role=msg.role, content=msg.all_content, tool_calls=list(tool_calls_map.values()), ) pending_tool_calls = final_msg.tool_calls req_messages.append(final_msg) # 持续请求,只要还有待处理的工具调用就继续处理调用 while pending_tool_calls: for tool_call in pending_tool_calls: try: func = tool_call.function parameters = json.loads(func.arguments) func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters) msg = llm_entities.Message( role='tool', content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id, ) yield msg req_messages.append(msg) except Exception as e: # 工具调用出错,添加一个报错信息到 req_messages err_msg = llm_entities.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id) yield err_msg req_messages.append(err_msg) if is_stream: tool_calls_map = {} async for msg in await query.use_llm_model.requester.invoke_llm_stream( query, query.use_llm_model, req_messages, query.use_funcs, extra_args=query.use_llm_model.model_entity.extra_args, ): yield msg if msg.tool_calls: for tool_call in msg.tool_calls: if tool_call.id not in tool_calls_map: tool_calls_map[tool_call.id] = llm_entities.ToolCall( id=tool_call.id, type=tool_call.type, function=llm_entities.FunctionCall( name=tool_call.function.name if tool_call.function else '', arguments='' ), ) if tool_call.function and tool_call.function.arguments: # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments final_msg = llm_entities.Message( role=msg.role, content=msg.all_content, tool_calls=list(tool_calls_map.values()), ) else: # 处理完所有调用,再次请求 msg = await query.use_llm_model.requester.invoke_llm( query, query.use_llm_model, req_messages, query.use_funcs, extra_args=query.use_llm_model.model_entity.extra_args, ) yield msg final_msg = msg pending_tool_calls = final_msg.tool_calls req_messages.append(final_msg)