mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-07 14:26:03 +00:00
530 lines
19 KiB
Python
530 lines
19 KiB
Python
"""DeerFlow LangGraph API Runner

Modeled on astrbot's deerflow_agent_runner implementation and adapted to
LangBot's Runner interface.

Features:
- Connects to the deer-flow backend through the LangGraph HTTP API
- Manages thread_id automatically (isolated per session)
- Parses SSE streaming responses
- Supports both streaming and non-streaming output
- Handles three kinds of events: values / messages-tuple / custom
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
import typing
|
||
from collections import deque
|
||
from dataclasses import dataclass, field
|
||
|
||
|
||
from langbot.pkg.provider import runner
|
||
from langbot.pkg.core import app
|
||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||
from langbot.libs.deerflow_api import client, errors, stream_utils
|
||
|
||
|
||
# Upper bound used both for the window of remembered message ids and for the
# list of messages retained from ``values`` events during a single run.
_MAX_VALUES_HISTORY = 200
|
||
|
||
|
||
@dataclass
class _StreamState:
    """Mutable bookkeeping for one streamed DeerFlow run."""

    # Text accumulated from message-event deltas.
    latest_text: str = ''
    # Not referenced by the code visible in this module.
    prev_text_for_streaming: str = ''
    # Most recent clarification text seen in values or message events.
    clarification_text: str = ''
    # Failure descriptions collected from ``custom`` events.
    task_failures: list[str] = field(default_factory=list)
    # Ids of messages already seen, plus their insertion order so the oldest
    # ids can be evicted first.
    seen_message_ids: set[str] = field(default_factory=set)
    seen_message_order: deque[str] = field(default_factory=deque)
    # For messages without an id: list index -> content fingerprint.
    no_id_message_fingerprints: dict[int, str] = field(default_factory=dict)
    # True once the first non-empty ``values`` event has been processed.
    baseline_initialized: bool = False
    # True once a ``values`` event has yielded AI text.
    has_values_text: bool = False
    # Messages collected from ``values`` events during this run.
    run_values_messages: list[dict[str, typing.Any]] = field(default_factory=list)
    # Set when the stream ended because of a timeout.
    timed_out: bool = False
|
||
|
||
|
||
@runner.runner_class('deerflow-api')
class DeerFlowAPIRunner(runner.RequestRunner):
    """Conversation request runner for the DeerFlow LangGraph API."""

    # HTTP client used to create threads and to stream runs.
    deerflow_client: client.AsyncDeerFlowClient
|
||
|
||
def __init__(self, ap: app.Application, pipeline_config: dict):
    """Read the ``deerflow-api`` settings and create the API client.

    Args:
        ap: Application instance, forwarded to the base runner.
        pipeline_config: Pipeline configuration; the runner settings are read
            from ``pipeline_config['ai']['deerflow-api']``.

    Raises:
        errors.DeerFlowAPIError: If ``api-base`` is missing or does not start
            with ``http://`` or ``https://``.
    """
    super().__init__(ap, pipeline_config)

    cfg = self.pipeline_config['ai']['deerflow-api']

    def _int_option(key: str, default: int) -> int:
        # A key that is present but null/empty falls back to the default
        # instead of crashing int(); other invalid values still raise.
        value = cfg.get(key)
        if value is None or value == '':
            return default
        return int(value)

    # ``or ''`` protects against an explicit null in the configuration.
    api_base = str(cfg.get('api-base') or '').strip()
    if not api_base.startswith(('http://', 'https://')):
        raise errors.DeerFlowAPIError(
            message='DeerFlow API Base URL 格式错误,必须以 http:// 或 https:// 开头',
        )

    self.api_base = api_base
    self.api_key = cfg.get('api-key') or ''
    self.auth_header = cfg.get('auth-header') or ''
    self.assistant_id = cfg.get('assistant-id') or 'lead_agent'
    self.model_name = cfg.get('model-name') or ''
    self.thinking_enabled = bool(cfg.get('thinking-enabled', False))
    self.plan_mode = bool(cfg.get('plan-mode', False))
    self.subagent_enabled = bool(cfg.get('subagent-enabled', False))
    self.max_concurrent_subagents = _int_option('max-concurrent-subagents', 3)
    self.timeout = _int_option('timeout', 300)
    self.recursion_limit = _int_option('recursion-limit', 1000)

    self.deerflow_client = client.AsyncDeerFlowClient(
        api_base=self.api_base,
        api_key=self.api_key,
        auth_header=self.auth_header,
    )
|
||
|
||
# ------------------------------------------------------------------
|
||
# 辅助方法
|
||
# ------------------------------------------------------------------
|
||
|
||
def _fingerprint_message(self, message: dict[str, typing.Any]) -> str:
    """Return a SHA-1 hex digest identifying the content of ``message``.

    The message is serialized as JSON with sorted keys so that key order does
    not affect the result; values JSON cannot encode are stringified. If
    serialization still fails, ``repr(message)`` is hashed instead.
    """
    try:
        raw = json.dumps(message, sort_keys=True, ensure_ascii=False, default=str)
    except (TypeError, ValueError):
        raw = repr(message)
    # backslashreplace keeps unencodable characters (e.g. lone surrogates)
    # visible in the hashed bytes instead of silently dropping them, so two
    # messages differing only in such characters get different fingerprints.
    return hashlib.sha1(raw.encode('utf-8', errors='backslashreplace')).hexdigest()
|
||
|
||
def _remember_seen_message_id(self, state: _StreamState, msg_id: str) -> None:
    """Record ``msg_id`` as seen, evicting the oldest ids beyond the cap.

    Empty ids and ids that are already recorded are ignored.
    """
    if not msg_id or msg_id in state.seen_message_ids:
        return

    state.seen_message_ids.add(msg_id)
    state.seen_message_order.append(msg_id)

    # Keep the set and the order queue in sync while trimming to the cap.
    # NOTE(review): an evicted id is no longer recognised as seen, so a
    # snapshot holding more than _MAX_VALUES_HISTORY id-bearing messages may
    # report old messages as new again - confirm this is acceptable.
    while len(state.seen_message_order) > _MAX_VALUES_HISTORY:
        dropped = state.seen_message_order.popleft()
        state.seen_message_ids.discard(dropped)
|
||
|
||
def _extract_new_messages_from_values(
    self,
    values_messages: list[typing.Any],
    state: _StreamState,
) -> list[dict[str, typing.Any]]:
    """Return the messages in ``values_messages`` not yet seen in this run.

    Messages with an id are deduplicated through ``state.seen_message_ids``.
    Messages without an id are compared by content fingerprint against the
    fingerprint previously stored for the same list index. Non-dict entries
    are skipped.
    """
    new_messages: list[dict[str, typing.Any]] = []
    no_id_indexes_seen: set[int] = set()

    for idx, msg in enumerate(values_messages):
        if not isinstance(msg, dict):
            continue

        msg_id = stream_utils.get_message_id(msg)
        if msg_id:
            if msg_id not in state.seen_message_ids:
                self._remember_seen_message_id(state, msg_id)
                new_messages.append(msg)
            continue

        no_id_indexes_seen.add(idx)
        fp = self._fingerprint_message(msg)
        if state.no_id_message_fingerprints.get(idx) == fp:
            continue
        state.no_id_message_fingerprints[idx] = fp
        new_messages.append(msg)

    # Forget fingerprints for indexes that no longer hold an id-less message.
    # The set difference is a fresh set, so deleting while looping is safe.
    for stale_idx in state.no_id_message_fingerprints.keys() - no_id_indexes_seen:
        del state.no_id_message_fingerprints[stale_idx]

    return new_messages
|
||
|
||
# ------------------------------------------------------------------
|
||
# 用户输入处理
|
||
# ------------------------------------------------------------------
|
||
|
||
def _build_user_content(
    self,
    prompt: str,
    image_urls: list[str],
) -> typing.Any:
    """Build LangGraph-compatible user content, with multimodal support.

    Args:
        prompt: User text.
        image_urls: Candidate image locations. Only ``http://``, ``https://``
            and ``data:`` URLs are kept; blank or non-string entries are
            skipped.

    Returns:
        ``prompt`` unchanged when there are no usable parts, otherwise a list
        of ``text`` / ``image_url`` content parts.
    """
    if not image_urls:
        return prompt

    allowed_schemes = ('http://', 'https://', 'data:')
    content: list[dict[str, typing.Any]] = []
    if prompt:
        content.append({'type': 'text', 'text': prompt})

    for url in image_urls:
        if not isinstance(url, str):
            continue
        url = url.strip()
        if not url:
            continue
        # URI schemes are case-insensitive, so compare a lower-cased prefix.
        # Only the first 8 characters (len('https://')) are lower-cased to
        # avoid copying large data URIs.
        if url[:8].lower().startswith(allowed_schemes):
            content.append({'type': 'image_url', 'image_url': {'url': url}})

    return content if content else prompt
|
||
|
||
def _preprocess_user_message(
    self,
    query: pipeline_query.Query,
) -> tuple[str, list[str]]:
    """Extract the plain text and the image URL list from the user message.

    String content is returned as the text with no images. List content is
    scanned element by element: ``text`` elements are concatenated,
    ``image_base64`` elements become data URIs and ``image_url`` elements are
    passed through. Any other content yields ``('', [])``.
    """
    content = query.user_message.content
    if isinstance(content, str):
        return content, []

    text_parts: list[str] = []
    image_urls: list[str] = []

    if isinstance(content, list):
        for ce in content:
            if ce.type == 'text':
                # Tolerate a text element whose text is None.
                text_parts.append(ce.text or '')
            elif ce.type == 'image_base64':
                b64 = getattr(ce, 'image_base64', '')
                if b64:
                    # Wrap a bare base64 payload as a data URI.
                    if not b64.startswith('data:'):
                        b64 = f'data:image/png;base64,{b64}'
                    image_urls.append(b64)
            elif ce.type == 'image_url':
                url = getattr(ce, 'image_url', '')
                if url:
                    image_urls.append(url)

    return ''.join(text_parts), image_urls
|
||
|
||
# ------------------------------------------------------------------
|
||
# 请求构造
|
||
# ------------------------------------------------------------------
|
||
|
||
def _build_messages(
    self,
    prompt: str,
    image_urls: list[str],
    system_prompt: str = '',
) -> list[dict[str, typing.Any]]:
    """Build the ``messages`` list sent as run input.

    A ``system`` message is placed first only when ``system_prompt`` is
    non-empty; the ``user`` message always follows, with content produced by
    ``_build_user_content``.
    """
    user_message = {
        'role': 'user',
        'content': self._build_user_content(prompt, image_urls),
    }
    if system_prompt:
        return [{'role': 'system', 'content': system_prompt}, user_message]
    return [user_message]
|
||
|
||
def _build_runtime_configurable(self, thread_id: str) -> dict[str, typing.Any]:
    """Build the runtime override mapping for a run on ``thread_id``.

    ``max_concurrent_subagents`` is included only when sub-agents are enabled,
    and ``model_name`` only when a model name is configured.
    """
    cfg: dict[str, typing.Any] = {
        'thread_id': thread_id,
        'thinking_enabled': self.thinking_enabled,
        'is_plan_mode': self.plan_mode,
        'subagent_enabled': self.subagent_enabled,
    }
    if self.subagent_enabled:
        cfg['max_concurrent_subagents'] = self.max_concurrent_subagents
    if self.model_name:
        cfg['model_name'] = self.model_name
    return cfg
|
||
|
||
def _build_payload(
    self,
    thread_id: str,
    prompt: str,
    image_urls: list[str],
    system_prompt: str = '',
) -> dict[str, typing.Any]:
    """Assemble the request body for a streamed run.

    The runtime overrides are sent twice: under ``config.configurable`` and,
    as a separate shallow copy, under ``context``.
    """
    runtime_configurable = self._build_runtime_configurable(thread_id)
    messages = self._build_messages(prompt, image_urls, system_prompt)
    return {
        'assistant_id': self.assistant_id,
        'input': {
            'messages': messages,
        },
        'stream_mode': ['values', 'messages-tuple', 'custom'],
        # DeerFlow 2.0 reads runtime overrides from config.configurable;
        # the context field is kept as well for backward compatibility.
        'context': dict(runtime_configurable),
        'config': {
            'recursion_limit': self.recursion_limit,
            'configurable': runtime_configurable,
        },
    }
|
||
|
||
# ------------------------------------------------------------------
|
||
# Session/Thread 管理
|
||
# ------------------------------------------------------------------
|
||
|
||
async def _ensure_thread_id(self, query: pipeline_query.Query) -> str:
    """Return the DeerFlow thread id for the session, creating one if needed.

    The id is stored in ``query.session.using_conversation.uuid``, the field
    LangBot uses to persist the conversation id (the original author notes
    the Dify runner does the same).

    Raises:
        errors.DeerFlowAPIError: If the create-thread response carries no
            ``thread_id``.
    """
    thread_id = query.session.using_conversation.uuid or ''
    if thread_id:
        return thread_id

    # Thread creation gets at most 30 seconds, or less if the configured
    # run timeout is shorter.
    thread = await self.deerflow_client.create_thread(timeout=min(30, self.timeout))
    # A non-dict response is reported through the same domain error instead
    # of failing with AttributeError on ``.get``.
    thread_id = thread.get('thread_id', '') if isinstance(thread, dict) else ''
    if not thread_id:
        raise errors.DeerFlowAPIError(
            message=f'DeerFlow create thread 返回数据缺少 thread_id: {thread}'
        )

    query.session.using_conversation.uuid = thread_id
    return thread_id
|
||
|
||
# ------------------------------------------------------------------
|
||
# 流式事件处理
|
||
# ------------------------------------------------------------------
|
||
|
||
def _handle_values_event(
    self,
    data: typing.Any,
    state: _StreamState,
) -> str | None:
    """Process a ``values`` event and return the latest AI text, if any.

    Messages not yet seen in this run are appended to
    ``state.run_values_messages`` (capped at ``_MAX_VALUES_HISTORY``), after
    which the latest AI text and clarification text are re-extracted from
    that list.

    Returns:
        The text from ``stream_utils.extract_latest_ai_text``, or ``None``
        when the event has no messages, nothing new arrived, or no text was
        found.
    """
    values_messages = stream_utils.extract_messages_from_values_data(data)
    if not values_messages:
        return None

    if state.baseline_initialized:
        new_messages = self._extract_new_messages_from_values(values_messages, state)
    else:
        # First snapshot: register every message for later deduplication.
        # NOTE(review): these messages are also collected as "new", so a
        # snapshot carrying earlier-turn history feeds that history into the
        # text extraction below - confirm this is intended.
        state.baseline_initialized = True
        new_messages = []
        for idx, msg in enumerate(values_messages):
            if not isinstance(msg, dict):
                continue
            new_messages.append(msg)
            msg_id = stream_utils.get_message_id(msg)
            if msg_id:
                self._remember_seen_message_id(state, msg_id)
            else:
                state.no_id_message_fingerprints[idx] = self._fingerprint_message(msg)

    if not new_messages:
        return None

    state.run_values_messages.extend(new_messages)
    # Keep only the most recent messages.
    if len(state.run_values_messages) > _MAX_VALUES_HISTORY:
        state.run_values_messages = state.run_values_messages[-_MAX_VALUES_HISTORY:]

    latest_text = stream_utils.extract_latest_ai_text(state.run_values_messages)
    if latest_text:
        state.has_values_text = True

    latest_clarification = stream_utils.extract_latest_clarification_text(
        state.run_values_messages,
    )
    if latest_clarification:
        state.clarification_text = latest_clarification

    return latest_text or None
|
||
|
||
def _handle_message_event(
    self,
    data: typing.Any,
    state: _StreamState,
) -> str | None:
    """Process a messages-tuple event and return the incremental text.

    Once a ``values`` event has supplied text, deltas from these events are
    no longer accumulated or returned. In that case, or when the event has no
    delta, the event is checked for clarification text instead.

    Returns:
        The delta appended to ``state.latest_text``, or ``None``.
    """
    delta = stream_utils.extract_ai_delta_from_event_data(data)
    if delta and not state.has_values_text:
        state.latest_text += delta
        return delta

    maybe_clar = stream_utils.extract_clarification_from_event_data(data)
    if maybe_clar:
        state.clarification_text = maybe_clar
    return None
|
||
|
||
def _build_final_text(self, state: _StreamState) -> str:
    """Build the final output text for a finished (or timed-out) run.

    Order of preference: clarification text, the text of the latest AI message
    collected from ``values`` events, the text accumulated from message
    deltas, a summary of task failures, and finally a fixed "empty response"
    message. A timeout notice is appended only to the two AI-text results.
    """
    if state.clarification_text:
        return state.clarification_text

    def with_timeout_notice(text: str) -> str:
        # One place for the notice that both AI-text branches append.
        if state.timed_out:
            text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
        return text

    # Prefer the text of the last AI message.
    latest_ai = stream_utils.extract_latest_ai_message(state.run_values_messages)
    if latest_ai:
        text = stream_utils.extract_text(latest_ai.get('content'))
        if text:
            return with_timeout_notice(text)

    if state.latest_text:
        return with_timeout_notice(state.latest_text)

    # Fall back to the collected task failure information.
    failure_text = stream_utils.build_task_failure_summary(state.task_failures)
    if failure_text:
        return failure_text

    return 'DeerFlow 返回空响应'
|
||
|
||
# ------------------------------------------------------------------
|
||
# 主流程
|
||
# ------------------------------------------------------------------
|
||
|
||
async def _stream_messages_chunk(
    self,
    query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
    """Stream the run, yielding cumulative text chunks and a final chunk.

    Each intermediate chunk carries the full text so far (not only the
    delta) with ``is_final=False``. After the stream ends or times out, one
    ``is_final=True`` chunk carries the text from ``_build_final_text``.

    Raises:
        errors.DeerFlowAPIError: If the stream delivers an ``error`` event.
    """
    plain_text, image_urls = self._preprocess_user_message(query)

    # LangBot's pipeline usually injects the system prompt during
    # prompt-preprocess already, so it is left empty here and that content
    # is sent to deerflow as part of the user message.
    system_prompt = ''

    thread_id = await self._ensure_thread_id(query)
    payload = self._build_payload(
        thread_id=thread_id,
        prompt=plain_text or 'continue',
        image_urls=image_urls,
        system_prompt=system_prompt,
    )

    state = _StreamState()
    prev_text = ''

    try:
        async for event in self.deerflow_client.stream_run(
            thread_id=thread_id,
            payload=payload,
            timeout=self.timeout,
        ):
            event_type = event.get('event')
            data = event.get('data')

            if event_type == 'values':
                new_full = self._handle_values_event(data, state)
                # Emit only when the full text actually changed.
                if new_full and new_full != prev_text:
                    prev_text = new_full
                    yield provider_message.MessageChunk(
                        role='assistant',
                        content=new_full,
                        is_final=False,
                    )
                continue

            if event_type in {'messages-tuple', 'messages', 'message'}:
                delta = self._handle_message_event(data, state)
                if delta:
                    prev_text = state.latest_text
                    yield provider_message.MessageChunk(
                        role='assistant',
                        content=prev_text,
                        is_final=False,
                    )
                continue

            if event_type == 'custom':
                state.task_failures.extend(
                    stream_utils.extract_task_failures_from_custom_event(data),
                )
                continue

            if event_type == 'error':
                raise errors.DeerFlowAPIError(
                    message=f'DeerFlow stream error event: {data}'
                )

            if event_type == 'end':
                break
    except (asyncio.TimeoutError, TimeoutError):
        # A timeout is not fatal: whatever was collected is still returned.
        self.ap.logger.warning(
            f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}'
        )
        state.timed_out = True

    # Final message.
    final_text = self._build_final_text(state)
    yield provider_message.MessageChunk(
        role='assistant',
        content=final_text,
        is_final=True,
    )
|
||
|
||
async def _messages(
    self,
    query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.Message, None]:
    """Consume the whole stream and yield one aggregated message.

    Events update the stream state exactly as in streaming mode, but nothing
    is yielded until the stream ends or times out.

    Raises:
        errors.DeerFlowAPIError: If the stream delivers an ``error`` event.
    """
    plain_text, image_urls = self._preprocess_user_message(query)

    thread_id = await self._ensure_thread_id(query)
    payload = self._build_payload(
        thread_id=thread_id,
        prompt=plain_text or 'continue',
        image_urls=image_urls,
    )

    state = _StreamState()

    try:
        async for event in self.deerflow_client.stream_run(
            thread_id=thread_id,
            payload=payload,
            timeout=self.timeout,
        ):
            event_type = event.get('event')
            data = event.get('data')

            if event_type == 'values':
                self._handle_values_event(data, state)
            elif event_type in {'messages-tuple', 'messages', 'message'}:
                self._handle_message_event(data, state)
            elif event_type == 'custom':
                state.task_failures.extend(
                    stream_utils.extract_task_failures_from_custom_event(data),
                )
            elif event_type == 'error':
                raise errors.DeerFlowAPIError(
                    message=f'DeerFlow stream error event: {data}'
                )
            elif event_type == 'end':
                break
    except (asyncio.TimeoutError, TimeoutError):
        # A timeout is not fatal: whatever was collected is still returned.
        self.ap.logger.warning(
            f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}'
        )
        state.timed_out = True

    final_text = self._build_final_text(state)
    yield provider_message.Message(
        role='assistant',
        content=final_text,
    )
|
||
|
||
async def run(
    self,
    query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.Message, None]:
    """Main entry point: stream if the adapter supports it, else aggregate.

    In streaming mode every yielded chunk gets a 1-based ``msg_sequence``
    before it is passed on.
    """
    if not await query.adapter.is_stream_output_supported():
        async for msg in self._messages(query):
            yield msg
        return

    msg_idx = 0
    async for msg in self._stream_messages_chunk(query):
        msg_idx += 1
        msg.msg_sequence = msg_idx
        yield msg
|