mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-07 22:36:02 +00:00
205 lines
6.6 KiB
Python
205 lines
6.6 KiB
Python
"""DeerFlow LangGraph HTTP API 客户端
|
||
|
||
参考 astrbot 的 deerflow_api_client 实现,使用 httpx 适配 LangBot 风格。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import codecs
|
||
import json
|
||
import typing
|
||
from collections.abc import AsyncGenerator
|
||
|
||
import httpx
|
||
|
||
from .errors import DeerFlowAPIError
|
||
|
||
|
||
# Upper bound (in characters) on SSE text buffered without a complete block.
# Once the buffer exceeds this, stream_run parses it as a single block and
# resets it.
SSE_MAX_BUFFER_CHARS = 1_048_576
|
||
|
||
|
||
def _normalize_sse_newlines(text: str) -> str:
    """Return *text* with every CRLF and bare CR turned into LF.

    A single newline convention keeps the ``'\\n\\n'`` SSE block split stable.
    """
    # CRLF pairs are collapsed first so they do not turn into two LFs.
    without_crlf = '\n'.join(text.split('\r\n'))
    return '\n'.join(without_crlf.split('\r'))
|
||
|
||
|
||
def _parse_sse_data_lines(data_lines: list[str]) -> typing.Any:
    """Decode the ``data:`` lines of one SSE event.

    The lines are joined with LF and parsed as a single JSON document. If that
    fails, each non-blank line is parsed as its own JSON document.

    Returns:
        The decoded JSON value; or, in the per-line fallback, the single
        decoded value or a list of decoded values; or the raw LF-joined text
        when the data is not valid JSON either way.
    """
    joined = '\n'.join(data_lines)
    try:
        return json.loads(joined)
    except json.JSONDecodeError:
        pass

    # Some LangGraph-compatible servers send several JSON fragments in one
    # SSE event, one per data line (for example a tuple payload).
    fragments = [piece for piece in (entry.strip() for entry in data_lines) if piece]
    try:
        decoded = [json.loads(fragment) for fragment in fragments]
    except json.JSONDecodeError:
        # One bad fragment means the payload is handed back as raw text.
        return joined

    if not decoded:
        return joined
    if len(decoded) == 1:
        return decoded[0]
    return decoded
|
||
|
||
|
||
def _parse_sse_block(block: str) -> dict[str, typing.Any] | None:
    """Turn one SSE block into an event dict.

    Only lines starting with ``event:`` or ``data:`` are read; every other
    line is ignored. If ``event:`` appears more than once the last one wins.

    Returns:
        ``{'event': name, 'data': payload}``, where ``name`` defaults to
        ``'message'`` and ``payload`` comes from ``_parse_sse_data_lines``;
        ``None`` when the block is blank or has no ``data:`` line.
    """
    if block.strip() == '':
        return None

    name = 'message'
    payload_lines: list[str] = []
    for raw_line in block.splitlines():
        field, colon, value = raw_line.partition(':')
        if not colon:
            # No colon at all, so this cannot be an event:/data: field.
            continue
        if field == 'event':
            name = value.strip()
        elif field == 'data':
            payload_lines.append(value.lstrip())

    if payload_lines:
        return {'event': name, 'data': _parse_sse_data_lines(payload_lines)}
    return None
|
||
|
||
|
||
class AsyncDeerFlowClient:
    """Async client for the DeerFlow LangGraph HTTP API.

    Each call opens its own ``httpx.AsyncClient`` (created with
    ``trust_env=True``) and closes it when the call finishes.
    """

    # Base URL of the DeerFlow server, stored without a trailing slash.
    api_base: str
    # Headers sent with every request (only ``Authorization``, when configured).
    headers: dict[str, str]

    def __init__(
        self,
        api_base: str = 'http://127.0.0.1:2026',
        api_key: str = '',
        auth_header: str = '',
    ) -> None:
        """Store the base URL and build the authorization header.

        Args:
            api_base: Server base URL; trailing slashes are removed.
            api_key: Sent as ``Authorization: Bearer <api_key>`` when
                ``auth_header`` is empty.
            auth_header: Full ``Authorization`` header value; takes precedence
                over ``api_key``.
        """
        self.api_base = api_base.rstrip('/')
        self.headers: dict[str, str] = {}
        if auth_header:
            self.headers['Authorization'] = auth_header
        elif api_key:
            self.headers['Authorization'] = f'Bearer {api_key}'

    @staticmethod
    def _split_trailing_cr(text: str) -> tuple[str, str]:
        """Split a trailing CR off *text* so it can be joined with the next chunk.

        Returns:
            ``(ready, carry)``: ``carry`` is ``'\\r'`` when *text* ended with a
            CR (``ready`` is *text* without it), otherwise ``''``.
        """
        if text.endswith('\r'):
            return text[:-1], '\r'
        return text, ''

    async def create_thread(self, timeout: float = 20) -> dict[str, typing.Any]:
        """Create a new LangGraph thread.

        Args:
            timeout: httpx timeout for the request.

        Returns:
            The decoded JSON response body (a dict with thread_id and related
            information).

        Raises:
            DeerFlowAPIError: If the response status is neither 200 nor 201.
        """
        url = f'{self.api_base}/api/langgraph/threads'
        payload = {'metadata': {}}

        async with httpx.AsyncClient(
            trust_env=True,
            timeout=timeout,
        ) as client:
            response = await client.post(
                url,
                headers=self.headers,
                json=payload,
            )
            if response.status_code not in (200, 201):
                raise DeerFlowAPIError(
                    operation='create thread',
                    status=response.status_code,
                    body=response.text,
                    url=url,
                )
            return response.json()

    async def delete_thread(self, thread_id: str, timeout: float = 20) -> None:
        """Delete the given thread.

        Statuses 200, 202, 204 and 404 are accepted, so deleting a thread the
        server does not know about is not an error.

        Args:
            thread_id: Identifier of the thread to delete.
            timeout: httpx timeout for the request.

        Raises:
            DeerFlowAPIError: If the response has any other status.
        """
        # NOTE(review): unlike the other endpoints this path is not under
        # /api/langgraph; confirm this is intended.
        url = f'{self.api_base}/api/threads/{thread_id}'

        async with httpx.AsyncClient(
            trust_env=True,
            timeout=timeout,
        ) as client:
            response = await client.delete(url, headers=self.headers)
            if response.status_code not in (200, 202, 204, 404):
                raise DeerFlowAPIError(
                    operation='delete thread',
                    status=response.status_code,
                    body=response.text,
                    url=url,
                    thread_id=thread_id,
                )

    async def stream_run(
        self,
        thread_id: str,
        payload: dict[str, typing.Any],
        timeout: float = 120,
    ) -> AsyncGenerator[dict[str, typing.Any], None]:
        """Run one LangGraph stream request and yield its events one by one.

        The response body is decoded incrementally as UTF-8 (invalid bytes are
        replaced), newlines are normalized to LF, and the text is cut into SSE
        blocks at blank lines.

        Args:
            thread_id: Thread to run on.
            payload: JSON body of the run request.
            timeout: Used for the read, write and pool timeouts; the connect
                timeout is capped at 30 seconds.

        Yields:
            Event dicts ``{'event': event_name, 'data': parsed_data}``.

        Raises:
            DeerFlowAPIError: If the response status is not 200.
        """
        url = f'{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream'

        # Streaming requests get their own timeout object so the read timeout
        # can be controlled separately from the connect timeout.
        stream_timeout = httpx.Timeout(
            connect=min(timeout, 30),
            read=timeout,
            write=timeout,
            pool=timeout,
        )

        async with httpx.AsyncClient(
            trust_env=True,
            timeout=stream_timeout,
        ) as client:
            async with client.stream(
                'POST',
                url,
                headers={
                    **self.headers,
                    'Accept': 'text/event-stream',
                    'Content-Type': 'application/json',
                },
                json=payload,
            ) as resp:
                if resp.status_code != 200:
                    body = await resp.aread()
                    raise DeerFlowAPIError(
                        operation='runs/stream request',
                        status=resp.status_code,
                        body=body.decode('utf-8', errors='replace'),
                        url=url,
                        thread_id=thread_id,
                    )

                decoder = codecs.getincrementaldecoder('utf-8')('replace')
                buffer = ''
                # A CR that ended the previous chunk. It is held back because
                # the next chunk may start with the LF of the same CRLF;
                # normalizing the two halves separately would produce a
                # spurious blank line and split the event in two.
                pending_cr = ''

                async for chunk in resp.aiter_bytes(8192):
                    text, pending_cr = self._split_trailing_cr(pending_cr + decoder.decode(chunk))
                    buffer += _normalize_sse_newlines(text)

                    while '\n\n' in buffer:
                        block, buffer = buffer.split('\n\n', 1)
                        parsed = _parse_sse_block(block)
                        if parsed is not None:
                            yield parsed

                    if len(buffer) > SSE_MAX_BUFFER_CHARS:
                        # Buffer grew past the cap without a block terminator:
                        # parse what is there as one block and start over.
                        parsed = _parse_sse_block(buffer)
                        if parsed is not None:
                            yield parsed
                        buffer = ''

                # Flush whatever is left: the held-back CR (if any) plus any
                # bytes still inside the incremental decoder.
                buffer += _normalize_sse_newlines(pending_cr + decoder.decode(b'', final=True))
                while '\n\n' in buffer:
                    block, buffer = buffer.split('\n\n', 1)
                    parsed = _parse_sse_block(block)
                    if parsed is not None:
                        yield parsed
                if buffer.strip():
                    parsed = _parse_sse_block(buffer)
                    if parsed is not None:
                        yield parsed
|