This commit is contained in:
Typer_Body
2026-06-07 02:38:05 +08:00
parent fd896c6974
commit 07b90f12a2
5 changed files with 19 additions and 48 deletions

View File

@@ -2,6 +2,7 @@
参考 astrbot 的 deerflow_api_client 实现,使用 httpx 适配 LangBot 风格。
"""
from __future__ import annotations
import codecs

View File

@@ -26,8 +26,5 @@ class DeerFlowAPIError(Exception):
msg = f'DeerFlow {operation} failed: status={status}, url={url}, body={body}'
if thread_id is not None:
msg = (
f'DeerFlow {operation} failed: thread_id={thread_id}, '
f'status={status}, url={url}, body={body}'
)
msg = f'DeerFlow {operation} failed: thread_id={thread_id}, status={status}, url={url}, body={body}'
super().__init__(msg)

View File

@@ -2,6 +2,7 @@
参考 astrbot 实现的 deerflow_stream_utils。
"""
from __future__ import annotations
import typing
@@ -60,9 +61,7 @@ def is_ai_message(message: dict[str, typing.Any]) -> bool:
msg_type = str(message.get('type', '')).lower()
if msg_type in {'ai', 'assistant', 'aimessage', 'aimessagechunk'}:
return True
if 'ai' in msg_type and all(
token not in msg_type for token in ('human', 'tool', 'system')
):
if 'ai' in msg_type and all(token not in msg_type for token in ('human', 'tool', 'system')):
return True
return False

View File

@@ -9,6 +9,7 @@
- 支持 streaming/非流式两种输出
- 处理 values / messages-tuple / custom 三种事件
"""
from __future__ import annotations
import asyncio
@@ -32,6 +33,7 @@ _MAX_VALUES_HISTORY = 200
@dataclass
class _StreamState:
"""流式状态跟踪"""
latest_text: str = ''
prev_text_for_streaming: str = ''
clarification_text: str = ''
@@ -258,9 +260,7 @@ class DeerFlowAPIRunner(runner.RequestRunner):
thread = await self.deerflow_client.create_thread(timeout=min(30, self.timeout))
thread_id = thread.get('thread_id', '')
if not thread_id:
raise errors.DeerFlowAPIError(
message=f'DeerFlow create thread 返回数据缺少 thread_id: {thread}'
)
raise errors.DeerFlowAPIError(message=f'DeerFlow create thread 返回数据缺少 thread_id: {thread}')
query.session.using_conversation.uuid = thread_id
return thread_id
@@ -298,9 +298,7 @@ class DeerFlowAPIRunner(runner.RequestRunner):
if new_messages:
state.run_values_messages.extend(new_messages)
if len(state.run_values_messages) > _MAX_VALUES_HISTORY:
state.run_values_messages = state.run_values_messages[
-_MAX_VALUES_HISTORY:
]
state.run_values_messages = state.run_values_messages[-_MAX_VALUES_HISTORY:]
latest_text = stream_utils.extract_latest_ai_text(state.run_values_messages)
if latest_text:
state.has_values_text = True
@@ -342,17 +340,13 @@ class DeerFlowAPIRunner(runner.RequestRunner):
text = stream_utils.extract_text(latest_ai.get('content'))
if text:
if state.timed_out:
text += (
f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
)
text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
return text
if state.latest_text:
text = state.latest_text
if state.timed_out:
text += (
f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
)
text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
return text
# 提取任务失败信息作兜底
@@ -401,11 +395,7 @@ class DeerFlowAPIRunner(runner.RequestRunner):
if event_type == 'values':
new_full = self._handle_values_event(data, state)
if new_full and new_full != prev_text:
delta = (
new_full[len(prev_text):]
if new_full.startswith(prev_text)
else new_full
)
delta = new_full[len(prev_text) :] if new_full.startswith(prev_text) else new_full
prev_text = new_full
if delta:
message_idx += 1
@@ -435,16 +425,12 @@ class DeerFlowAPIRunner(runner.RequestRunner):
continue
if event_type == 'error':
raise errors.DeerFlowAPIError(
message=f'DeerFlow stream error event: {data}'
)
raise errors.DeerFlowAPIError(message=f'DeerFlow stream error event: {data}')
if event_type == 'end':
break
except (asyncio.TimeoutError, TimeoutError):
self.ap.logger.warning(
f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}'
)
self.ap.logger.warning(f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}')
state.timed_out = True
# 最终消息
@@ -495,16 +481,12 @@ class DeerFlowAPIRunner(runner.RequestRunner):
continue
if event_type == 'error':
raise errors.DeerFlowAPIError(
message=f'DeerFlow stream error event: {data}'
)
raise errors.DeerFlowAPIError(message=f'DeerFlow stream error event: {data}')
if event_type == 'end':
break
except (asyncio.TimeoutError, TimeoutError):
self.ap.logger.warning(
f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}'
)
self.ap.logger.warning(f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}')
state.timed_out = True
final_text = self._build_final_text(state)

View File

@@ -35,9 +35,7 @@ class WeKnoraAPIRunner(runner.RequestRunner):
base_url = self.pipeline_config['ai']['weknora-api'].get('base-url', '').strip()
if not base_url:
raise errors.WeKnoraAPIError(
'WeKnora Base URL 未配置,请填入服务器地址,例如 http://localhost:8080/api/v1'
)
raise errors.WeKnoraAPIError('WeKnora Base URL 未配置,请填入服务器地址,例如 http://localhost:8080/api/v1')
self.weknora_client = client.AsyncWeKnoraClient(
api_key=api_key,
@@ -65,9 +63,7 @@ class WeKnoraAPIRunner(runner.RequestRunner):
if not session_id:
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
session_id = await self.weknora_client.create_session(
title=f'IM Chat - {user_tag}'
)
session_id = await self.weknora_client.create_session(title=f'IM Chat - {user_tag}')
query.session.using_conversation.uuid = session_id
return session_id
@@ -343,9 +339,7 @@ class WeKnoraAPIRunner(runner.RequestRunner):
msg.msg_sequence = msg_idx
yield msg
else:
raise errors.WeKnoraAPIError(
f'不支持的 WeKnora 应用类型: {app_type}'
)
raise errors.WeKnoraAPIError(f'不支持的 WeKnora 应用类型: {app_type}')
else:
if app_type == 'agent':
async for msg in self._agent_chat_messages(query):
@@ -354,6 +348,4 @@ class WeKnoraAPIRunner(runner.RequestRunner):
async for msg in self._chat_messages(query):
yield msg
else:
raise errors.WeKnoraAPIError(
f'不支持的 WeKnora 应用类型: {app_type}'
)
raise errors.WeKnoraAPIError(f'不支持的 WeKnora 应用类型: {app_type}')