feat: model fallback chain (#2017) (#2018)

This commit is contained in:
Junyan Chin
2026-03-12 03:33:05 +08:00
committed by GitHub
parent 89064a9d5b
commit 79311ccde3
15 changed files with 534 additions and 113 deletions
+131 -46
View File
@@ -4,6 +4,7 @@ import json
import copy
import typing
from .. import runner
from ..modelmgr import requester as modelmgr_requester
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.rag.context as rag_context
@@ -26,19 +27,109 @@ Respond in the same language as the user's input.
@runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器"""
"""Local agent request runner"""
class ToolCallTracker:
"""工具调用追踪器"""
async def _get_model_candidates(
self,
query: pipeline_query.Query,
) -> list[modelmgr_requester.RuntimeLLMModel]:
"""Build ordered list of models to try: primary model + fallback models."""
candidates = []
def __init__(self):
self.active_calls: dict[str, dict] = {}
self.completed_calls: list[provider_message.ToolCall] = []
# Primary model
if query.use_llm_model_uuid:
try:
primary = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
candidates.append(primary)
except ValueError:
self.ap.logger.warning(f'Primary model {query.use_llm_model_uuid} not found')
# Fallback models
fallback_uuids = (query.variables or {}).get('_fallback_model_uuids', [])
for fb_uuid in fallback_uuids:
try:
fb_model = await self.ap.model_mgr.get_model_by_uuid(fb_uuid)
candidates.append(fb_model)
except ValueError:
self.ap.logger.warning(f'Fallback model {fb_uuid} not found, skipping')
return candidates
async def _invoke_with_fallback(
self,
query: pipeline_query.Query,
candidates: list[modelmgr_requester.RuntimeLLMModel],
messages: list,
funcs: list,
remove_think: bool,
) -> tuple[provider_message.Message, modelmgr_requester.RuntimeLLMModel]:
"""Try non-streaming invocation with sequential fallback. Returns (message, model_used)."""
last_error = None
for model in candidates:
try:
msg = await model.provider.invoke_llm(
query,
model,
messages,
funcs if model.model_entity.abilities.__contains__('func_call') else [],
extra_args=model.model_entity.extra_args,
remove_think=remove_think,
)
return msg, model
except Exception as e:
last_error = e
self.ap.logger.warning(f'Model {model.model_entity.name} failed: {e}, trying next fallback...')
raise last_error or RuntimeError('No model candidates available')
async def _invoke_stream_with_fallback(
self,
query: pipeline_query.Query,
candidates: list[modelmgr_requester.RuntimeLLMModel],
messages: list,
funcs: list,
remove_think: bool,
) -> tuple[typing.AsyncGenerator, modelmgr_requester.RuntimeLLMModel]:
"""Try streaming invocation with sequential fallback. Returns (stream_generator, model_used).
Fallback is only possible before any chunks have been yielded to the client.
Once streaming starts, the model is committed.
"""
last_error = None
for model in candidates:
try:
stream = model.provider.invoke_llm_stream(
query,
model,
messages,
funcs if model.model_entity.abilities.__contains__('func_call') else [],
extra_args=model.model_entity.extra_args,
remove_think=remove_think,
)
# Attempt to get the first chunk to verify the stream works
first_chunk = await stream.__anext__()
async def _chain_stream(first, rest):
yield first
async for chunk in rest:
yield chunk
return _chain_stream(first_chunk, stream), model
except StopAsyncIteration:
# Empty stream — treat as success (model returned nothing)
async def _empty_stream():
return
yield # make it a generator
return _empty_stream(), model
except Exception as e:
last_error = e
self.ap.logger.warning(f'Model {model.model_entity.name} stream failed: {e}, trying next fallback...')
raise last_error or RuntimeError('No model candidates available')
async def run(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message | provider_message.MessageChunk, None]:
"""运行请求"""
"""Run request"""
pending_tool_calls = []
# Get knowledge bases list (new field)
@@ -119,51 +210,51 @@ class LocalAgentRunner(runner.RequestRunner):
remove_think = query.pipeline_config['output'].get('misc', '').get('remove-think')
use_llm_model = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
# Build ordered candidate list (primary + fallbacks)
candidates = await self._get_model_candidates(query)
if not candidates:
raise RuntimeError('No LLM model configured for local-agent runner')
self.ap.logger.debug(
f'localagent req: query={query.query_id} req_messages={req_messages} use_llm_model={query.use_llm_model_uuid}'
f'localagent req: query={query.query_id} req_messages={req_messages} '
f'candidates={[m.model_entity.name for m in candidates]}'
)
if not is_stream:
# 非流式输出,直接请求
msg = await use_llm_model.provider.invoke_llm(
# Non-streaming: invoke with fallback
msg, use_llm_model = await self._invoke_with_fallback(
query,
use_llm_model,
candidates,
req_messages,
query.use_funcs,
extra_args=use_llm_model.model_entity.extra_args,
remove_think=remove_think,
remove_think,
)
yield msg
final_msg = msg
else:
# 流式输出,需要处理工具调用
# Streaming: invoke with fallback
tool_calls_map: dict[str, provider_message.ToolCall] = {}
msg_idx = 0
accumulated_content = '' # 从开始累积的所有内容
accumulated_content = ''
last_role = 'assistant'
msg_sequence = 1
async for msg in use_llm_model.provider.invoke_llm_stream(
stream_src, use_llm_model = await self._invoke_stream_with_fallback(
query,
use_llm_model,
candidates,
req_messages,
query.use_funcs,
extra_args=use_llm_model.model_entity.extra_args,
remove_think=remove_think,
):
remove_think,
)
async for msg in stream_src:
msg_idx = msg_idx + 1
# 记录角色
if msg.role:
last_role = msg.role
# 累积内容
if msg.content:
accumulated_content += msg.content
# 处理工具调用
if msg.tool_calls:
for tool_call in msg.tool_calls:
if tool_call.id not in tool_calls_map:
@@ -175,21 +266,18 @@ class LocalAgentRunner(runner.RequestRunner):
),
)
if tool_call.function and tool_call.function.arguments:
# 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
# continue
# 每8个chunk或最后一个chunk时,输出所有累积的内容
if msg_idx % 8 == 0 or msg.is_final:
msg_sequence += 1
yield provider_message.MessageChunk(
role=last_role,
content=accumulated_content, # 输出所有累积内容
content=accumulated_content,
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
is_final=msg.is_final,
msg_sequence=msg_sequence,
)
# 创建最终消息用于后续处理
final_msg = provider_message.MessageChunk(
role=last_role,
content=accumulated_content,
@@ -204,7 +292,8 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(final_msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
# Once a model succeeds, commit to it for the tool call loop
# (no fallback mid-conversation — different models may interpret tool results differently)
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
@@ -245,7 +334,6 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = provider_message.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
yield err_msg
@@ -253,39 +341,38 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(err_msg)
self.ap.logger.debug(
f'localagent req: query={query.query_id} req_messages={req_messages} use_llm_model={query.use_llm_model_uuid}'
f'localagent req: query={query.query_id} req_messages={req_messages} '
f'use_llm_model={use_llm_model.model_entity.name}'
)
if is_stream:
tool_calls_map = {}
msg_idx = 0
accumulated_content = '' # 从开始累积的所有内容
accumulated_content = ''
last_role = 'assistant'
msg_sequence = first_end_sequence
async for msg in use_llm_model.provider.invoke_llm_stream(
tool_stream_src = use_llm_model.provider.invoke_llm_stream(
query,
use_llm_model,
req_messages,
query.use_funcs,
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
extra_args=use_llm_model.model_entity.extra_args,
remove_think=remove_think,
):
)
async for msg in tool_stream_src:
msg_idx += 1
# 记录角色
if msg.role:
last_role = msg.role
# 第一次请求工具调用时的内容
# Prepend first-round content on first chunk of tool-call round
if msg_idx == 1:
accumulated_content = first_content if first_content is not None else accumulated_content
# 累积内容
if msg.content:
accumulated_content += msg.content
# 处理工具调用
if msg.tool_calls:
for tool_call in msg.tool_calls:
if tool_call.id not in tool_calls_map:
@@ -297,15 +384,13 @@ class LocalAgentRunner(runner.RequestRunner):
),
)
if tool_call.function and tool_call.function.arguments:
# 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
# 每8个chunk或最后一个chunk时,输出所有累积的内容
if msg_idx % 8 == 0 or msg.is_final:
msg_sequence += 1
yield provider_message.MessageChunk(
role=last_role,
content=accumulated_content, # 输出所有累积内容
content=accumulated_content,
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
is_final=msg.is_final,
msg_sequence=msg_sequence,
@@ -318,12 +403,12 @@ class LocalAgentRunner(runner.RequestRunner):
msg_sequence=msg_sequence,
)
else:
# 处理完所有调用,再次请求
# Non-streaming: use committed model directly (no fallback in tool loop)
msg = await use_llm_model.provider.invoke_llm(
query,
use_llm_model,
req_messages,
query.use_funcs,
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
extra_args=use_llm_model.model_entity.extra_args,
remove_think=remove_think,
)