mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-13 01:06:03 +00:00
deerflow
This commit is contained in:
5
src/langbot/libs/deerflow_api/__init__.py
Normal file
5
src/langbot/libs/deerflow_api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .client import AsyncDeerFlowClient
|
||||
from .errors import DeerFlowAPIError
|
||||
from . import stream_utils
|
||||
|
||||
__all__ = ['AsyncDeerFlowClient', 'DeerFlowAPIError', 'stream_utils']
|
||||
203
src/langbot/libs/deerflow_api/client.py
Normal file
203
src/langbot/libs/deerflow_api/client.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""DeerFlow LangGraph HTTP API 客户端
|
||||
|
||||
参考 astrbot 的 deerflow_api_client 实现,使用 httpx 适配 LangBot 风格。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import json
|
||||
import typing
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from .errors import DeerFlowAPIError
|
||||
|
||||
|
||||
SSE_MAX_BUFFER_CHARS = 1_048_576
|
||||
|
||||
|
||||
def _normalize_sse_newlines(text: str) -> str:
|
||||
"""规范化 CRLF/CR 为 LF,确保 SSE 块分割稳定"""
|
||||
return text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
|
||||
def _parse_sse_data_lines(data_lines: list[str]) -> typing.Any:
|
||||
raw_data = '\n'.join(data_lines)
|
||||
try:
|
||||
return json.loads(raw_data)
|
||||
except json.JSONDecodeError:
|
||||
# 某些 LangGraph 兼容服务端会在单个 SSE 事件中用多个 data 行
|
||||
# 发送多段 JSON 片段(例如 tuple payload)
|
||||
parsed_lines: list[typing.Any] = []
|
||||
can_parse_all = True
|
||||
for line in data_lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
parsed_lines.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
can_parse_all = False
|
||||
break
|
||||
if can_parse_all and parsed_lines:
|
||||
return parsed_lines[0] if len(parsed_lines) == 1 else parsed_lines
|
||||
return raw_data
|
||||
|
||||
|
||||
def _parse_sse_block(block: str) -> dict[str, typing.Any] | None:
|
||||
if not block.strip():
|
||||
return None
|
||||
|
||||
event_name = 'message'
|
||||
data_lines: list[str] = []
|
||||
for line in block.splitlines():
|
||||
if line.startswith('event:'):
|
||||
event_name = line[6:].strip()
|
||||
elif line.startswith('data:'):
|
||||
data_lines.append(line[5:].lstrip())
|
||||
|
||||
if not data_lines:
|
||||
return None
|
||||
return {'event': event_name, 'data': _parse_sse_data_lines(data_lines)}
|
||||
|
||||
|
||||
class AsyncDeerFlowClient:
|
||||
"""DeerFlow LangGraph HTTP API 客户端"""
|
||||
|
||||
api_base: str
|
||||
headers: dict[str, str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_base: str = 'http://127.0.0.1:2026',
|
||||
api_key: str = '',
|
||||
auth_header: str = '',
|
||||
) -> None:
|
||||
self.api_base = api_base.rstrip('/')
|
||||
self.headers: dict[str, str] = {}
|
||||
if auth_header:
|
||||
self.headers['Authorization'] = auth_header
|
||||
elif api_key:
|
||||
self.headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
async def create_thread(self, timeout: float = 20) -> dict[str, typing.Any]:
|
||||
"""创建一个新的 LangGraph thread
|
||||
|
||||
Returns:
|
||||
包含 thread_id 等信息的字典
|
||||
"""
|
||||
url = f'{self.api_base}/api/langgraph/threads'
|
||||
payload = {'metadata': {}}
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
trust_env=True,
|
||||
timeout=timeout,
|
||||
) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
)
|
||||
if response.status_code not in (200, 201):
|
||||
raise DeerFlowAPIError(
|
||||
operation='create thread',
|
||||
status=response.status_code,
|
||||
body=response.text,
|
||||
url=url,
|
||||
)
|
||||
return response.json()
|
||||
|
||||
async def delete_thread(self, thread_id: str, timeout: float = 20) -> None:
|
||||
"""删除指定 thread"""
|
||||
url = f'{self.api_base}/api/threads/{thread_id}'
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
trust_env=True,
|
||||
timeout=timeout,
|
||||
) as client:
|
||||
response = await client.delete(url, headers=self.headers)
|
||||
if response.status_code not in (200, 202, 204, 404):
|
||||
raise DeerFlowAPIError(
|
||||
operation='delete thread',
|
||||
status=response.status_code,
|
||||
body=response.text,
|
||||
url=url,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
async def stream_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
payload: dict[str, typing.Any],
|
||||
timeout: float = 120,
|
||||
) -> AsyncGenerator[dict[str, typing.Any], None]:
|
||||
"""运行一次 LangGraph stream 请求,逐事件 yield
|
||||
|
||||
Yields:
|
||||
事件字典 {'event': event_name, 'data': parsed_data}
|
||||
"""
|
||||
url = f'{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream'
|
||||
|
||||
# 流式请求使用单独的 read timeout 控制
|
||||
stream_timeout = httpx.Timeout(
|
||||
connect=min(timeout, 30),
|
||||
read=timeout,
|
||||
write=timeout,
|
||||
pool=timeout,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
trust_env=True,
|
||||
timeout=stream_timeout,
|
||||
) as client:
|
||||
async with client.stream(
|
||||
'POST',
|
||||
url,
|
||||
headers={
|
||||
**self.headers,
|
||||
'Accept': 'text/event-stream',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
json=payload,
|
||||
) as resp:
|
||||
if resp.status_code != 200:
|
||||
body = await resp.aread()
|
||||
raise DeerFlowAPIError(
|
||||
operation='runs/stream request',
|
||||
status=resp.status_code,
|
||||
body=body.decode('utf-8', errors='replace'),
|
||||
url=url,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
decoder = codecs.getincrementaldecoder('utf-8')('replace')
|
||||
buffer = ''
|
||||
|
||||
async for chunk in resp.aiter_bytes(8192):
|
||||
buffer += _normalize_sse_newlines(decoder.decode(chunk))
|
||||
|
||||
while '\n\n' in buffer:
|
||||
block, buffer = buffer.split('\n\n', 1)
|
||||
parsed = _parse_sse_block(block)
|
||||
if parsed is not None:
|
||||
yield parsed
|
||||
|
||||
if len(buffer) > SSE_MAX_BUFFER_CHARS:
|
||||
# 缓冲区过大,强制 flush
|
||||
parsed = _parse_sse_block(buffer)
|
||||
if parsed is not None:
|
||||
yield parsed
|
||||
buffer = ''
|
||||
|
||||
# flush 剩余内容
|
||||
buffer += _normalize_sse_newlines(decoder.decode(b'', final=True))
|
||||
while '\n\n' in buffer:
|
||||
block, buffer = buffer.split('\n\n', 1)
|
||||
parsed = _parse_sse_block(block)
|
||||
if parsed is not None:
|
||||
yield parsed
|
||||
if buffer.strip():
|
||||
parsed = _parse_sse_block(buffer)
|
||||
if parsed is not None:
|
||||
yield parsed
|
||||
33
src/langbot/libs/deerflow_api/errors.py
Normal file
33
src/langbot/libs/deerflow_api/errors.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class DeerFlowAPIError(Exception):
|
||||
"""DeerFlow API 请求失败"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
operation: str = '',
|
||||
status: int = 0,
|
||||
body: str = '',
|
||||
url: str = '',
|
||||
thread_id: str | None = None,
|
||||
message: str = '',
|
||||
) -> None:
|
||||
self.operation = operation
|
||||
self.status = status
|
||||
self.body = body
|
||||
self.url = url
|
||||
self.thread_id = thread_id
|
||||
|
||||
if message:
|
||||
super().__init__(message)
|
||||
return
|
||||
|
||||
msg = f'DeerFlow {operation} failed: status={status}, url={url}, body={body}'
|
||||
if thread_id is not None:
|
||||
msg = (
|
||||
f'DeerFlow {operation} failed: thread_id={thread_id}, '
|
||||
f'status={status}, url={url}, body={body}'
|
||||
)
|
||||
super().__init__(msg)
|
||||
213
src/langbot/libs/deerflow_api/stream_utils.py
Normal file
213
src/langbot/libs/deerflow_api/stream_utils.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""DeerFlow LangGraph 流式响应解析工具
|
||||
|
||||
参考 astrbot 实现的 deerflow_stream_utils。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
def extract_text(content: typing.Any) -> str:
|
||||
"""从消息 content 中提取纯文本"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, dict):
|
||||
if isinstance(content.get('text'), str):
|
||||
return content['text']
|
||||
if 'content' in content:
|
||||
return extract_text(content.get('content'))
|
||||
if 'kwargs' in content and isinstance(content['kwargs'], dict):
|
||||
return extract_text(content['kwargs'].get('content'))
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
item_type = item.get('type')
|
||||
if item_type == 'text' and isinstance(item.get('text'), str):
|
||||
parts.append(item['text'])
|
||||
elif 'content' in item:
|
||||
parts.append(extract_text(item['content']))
|
||||
return '\n'.join([p for p in parts if p]).strip()
|
||||
return str(content) if content is not None else ''
|
||||
|
||||
|
||||
def extract_messages_from_values_data(data: typing.Any) -> list[typing.Any]:
|
||||
"""从 values 事件中提取 messages 列表"""
|
||||
candidates: list[typing.Any] = []
|
||||
if isinstance(data, dict):
|
||||
candidates.append(data)
|
||||
if isinstance(data.get('values'), dict):
|
||||
candidates.append(data['values'])
|
||||
elif isinstance(data, list):
|
||||
candidates.extend([x for x in data if isinstance(x, dict)])
|
||||
|
||||
for item in candidates:
|
||||
messages = item.get('messages')
|
||||
if isinstance(messages, list):
|
||||
return messages
|
||||
return []
|
||||
|
||||
|
||||
def is_ai_message(message: dict[str, typing.Any]) -> bool:
|
||||
"""判断是否为 AI/assistant 消息"""
|
||||
role = str(message.get('role', '')).lower()
|
||||
if role in {'assistant', 'ai'}:
|
||||
return True
|
||||
|
||||
msg_type = str(message.get('type', '')).lower()
|
||||
if msg_type in {'ai', 'assistant', 'aimessage', 'aimessagechunk'}:
|
||||
return True
|
||||
if 'ai' in msg_type and all(
|
||||
token not in msg_type for token in ('human', 'tool', 'system')
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def extract_latest_ai_text(messages: Iterable[typing.Any]) -> str:
|
||||
"""获取最近一条 AI 消息的文本内容"""
|
||||
if isinstance(messages, (list, tuple)):
|
||||
iterable = reversed(messages)
|
||||
else:
|
||||
iterable = reversed(list(messages))
|
||||
|
||||
for msg in iterable:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if is_ai_message(msg):
|
||||
text = extract_text(msg.get('content'))
|
||||
if text:
|
||||
return text
|
||||
return ''
|
||||
|
||||
|
||||
def extract_latest_ai_message(messages: Iterable[typing.Any]) -> dict[str, typing.Any] | None:
|
||||
"""获取最近一条 AI 消息对象"""
|
||||
if isinstance(messages, (list, tuple)):
|
||||
iterable = reversed(messages)
|
||||
else:
|
||||
iterable = reversed(list(messages))
|
||||
|
||||
for msg in iterable:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if is_ai_message(msg):
|
||||
return msg
|
||||
return None
|
||||
|
||||
|
||||
def is_clarification_tool_message(message: dict[str, typing.Any]) -> bool:
|
||||
"""判断是否为澄清问题工具消息"""
|
||||
msg_type = str(message.get('type', '')).lower()
|
||||
tool_name = str(message.get('name', '')).lower()
|
||||
return msg_type == 'tool' and tool_name == 'ask_clarification'
|
||||
|
||||
|
||||
def extract_latest_clarification_text(messages: Iterable[typing.Any]) -> str:
|
||||
"""提取最近的澄清问题文本"""
|
||||
if isinstance(messages, (list, tuple)):
|
||||
iterable = reversed(messages)
|
||||
else:
|
||||
iterable = reversed(list(messages))
|
||||
|
||||
for msg in iterable:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if is_clarification_tool_message(msg):
|
||||
text = extract_text(msg.get('content'))
|
||||
if text:
|
||||
return text
|
||||
return ''
|
||||
|
||||
|
||||
def get_message_id(message: typing.Any) -> str:
|
||||
"""提取消息 ID"""
|
||||
if not isinstance(message, dict):
|
||||
return ''
|
||||
msg_id = message.get('id')
|
||||
return msg_id if isinstance(msg_id, str) else ''
|
||||
|
||||
|
||||
def extract_event_message_obj(data: typing.Any) -> dict[str, typing.Any] | None:
|
||||
"""从事件 data 中提取消息对象"""
|
||||
msg_obj = data
|
||||
if isinstance(data, (list, tuple)) and data:
|
||||
msg_obj = data[0]
|
||||
if isinstance(msg_obj, dict) and isinstance(msg_obj.get('data'), dict):
|
||||
msg_obj = msg_obj['data']
|
||||
return msg_obj if isinstance(msg_obj, dict) else None
|
||||
|
||||
|
||||
def extract_ai_delta_from_event_data(data: typing.Any) -> str:
|
||||
"""从 messages-tuple 事件中提取 AI delta 文本"""
|
||||
msg_obj = extract_event_message_obj(data)
|
||||
if not msg_obj:
|
||||
return ''
|
||||
if is_ai_message(msg_obj):
|
||||
return extract_text(msg_obj.get('content'))
|
||||
return ''
|
||||
|
||||
|
||||
def extract_clarification_from_event_data(data: typing.Any) -> str:
|
||||
"""从事件中提取澄清问题"""
|
||||
msg_obj = extract_event_message_obj(data)
|
||||
if not msg_obj:
|
||||
return ''
|
||||
if is_clarification_tool_message(msg_obj):
|
||||
return extract_text(msg_obj.get('content'))
|
||||
return ''
|
||||
|
||||
|
||||
def _iter_custom_event_items(data: typing.Any) -> list[dict[str, typing.Any]]:
|
||||
items: list[dict[str, typing.Any]] = []
|
||||
if isinstance(data, dict):
|
||||
return [data]
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if isinstance(item, dict):
|
||||
items.append(item)
|
||||
elif isinstance(item, (list, tuple)):
|
||||
for nested in item:
|
||||
if isinstance(nested, dict):
|
||||
items.append(nested)
|
||||
return items
|
||||
|
||||
|
||||
def extract_task_failures_from_custom_event(data: typing.Any) -> list[str]:
|
||||
"""从 custom 事件中提取子任务失败信息"""
|
||||
failures: list[str] = []
|
||||
for item in _iter_custom_event_items(data):
|
||||
event_type = str(item.get('type', '')).lower()
|
||||
if event_type not in {'task_failed', 'task_timed_out'}:
|
||||
continue
|
||||
|
||||
task_id = str(item.get('task_id', '')).strip()
|
||||
error_text = extract_text(item.get('error')).strip()
|
||||
if task_id and error_text:
|
||||
failures.append(f'{task_id}: {error_text}')
|
||||
elif error_text:
|
||||
failures.append(error_text)
|
||||
elif task_id:
|
||||
failures.append(f'{task_id}: unknown error')
|
||||
else:
|
||||
failures.append('unknown task failure')
|
||||
return failures
|
||||
|
||||
|
||||
def build_task_failure_summary(failures: list[str]) -> str:
|
||||
"""构建任务失败摘要"""
|
||||
if not failures:
|
||||
return ''
|
||||
deduped: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for failure in failures:
|
||||
if failure not in seen:
|
||||
seen.add(failure)
|
||||
deduped.append(failure)
|
||||
if len(deduped) == 1:
|
||||
return f'DeerFlow subtask failed: {deduped[0]}'
|
||||
joined = '\n'.join([f'- {item}' for item in deduped[:5]])
|
||||
return f'DeerFlow subtasks failed:\n{joined}'
|
||||
Reference in New Issue
Block a user