Merge pull request #2230 from langbot-app/feat/addweknoradeerflow

Add DeerFlow LangGraph API as a Provider Runner
This commit is contained in:
Haoxuan Xing
2026-06-07 12:22:55 +08:00
committed by GitHub
20 changed files with 1788 additions and 8 deletions

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from .. import migration
@migration.migration_class('weknora-api-config', 42)
class WeKnoraAPICfgMigration(migration.Migration):
"""WeKnora API 配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'weknora-api' not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['weknora-api'] = {
'base-url': 'http://localhost:8080/api/v1',
'app-type': 'agent',
'api-key': '',
'agent-id': 'builtin-smart-reasoning',
'knowledge-base-ids': [],
'web-search-enabled': False,
'timeout': 120,
'base-prompt': '请回答用户的问题。',
}
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from .. import migration
@migration.migration_class('deerflow-api-config', 43)
class DeerFlowAPICfgMigration(migration.Migration):
"""DeerFlow API 配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'deerflow-api' not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['deerflow-api'] = {
'api-base': 'http://127.0.0.1:2026',
'api-key': '',
'auth-header': '',
'assistant-id': 'lead_agent',
'model-name': '',
'thinking-enabled': False,
'plan-mode': False,
'subagent-enabled': False,
'max-concurrent-subagents': 3,
'timeout': 300,
'recursion-limit': 1000,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,511 @@
"""DeerFlow LangGraph API Runner
参考 astrbot 的 deerflow_agent_runner 实现,适配 LangBot 的 Runner 接口。
特点:
- 使用 LangGraph HTTP API 接入 deer-flow 后端
- 自动管理 thread_id按 session 隔离)
- 支持 SSE 流式响应解析
- 支持 streaming/非流式两种输出
- 处理 values / messages-tuple / custom 三种事件
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import typing
from collections import deque
from dataclasses import dataclass, field
from langbot.pkg.provider import runner
from langbot.pkg.core import app
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from langbot.libs.deerflow_api import client, errors, stream_utils
_MAX_VALUES_HISTORY = 200
@dataclass
class _StreamState:
"""流式状态跟踪"""
latest_text: str = ''
prev_text_for_streaming: str = ''
clarification_text: str = ''
task_failures: list[str] = field(default_factory=list)
seen_message_ids: set[str] = field(default_factory=set)
seen_message_order: deque[str] = field(default_factory=deque)
no_id_message_fingerprints: dict[int, str] = field(default_factory=dict)
baseline_initialized: bool = False
has_values_text: bool = False
run_values_messages: list[dict[str, typing.Any]] = field(default_factory=list)
timed_out: bool = False
@runner.runner_class('deerflow-api')
class DeerFlowAPIRunner(runner.RequestRunner):
"""DeerFlow LangGraph API 对话请求器"""
deerflow_client: client.AsyncDeerFlowClient
def __init__(self, ap: app.Application, pipeline_config: dict):
super().__init__(ap, pipeline_config)
cfg = self.pipeline_config['ai']['deerflow-api']
api_base = cfg.get('api-base', '').strip()
if not api_base or not api_base.startswith(('http://', 'https://')):
raise errors.DeerFlowAPIError(
message='DeerFlow API Base URL 格式错误,必须以 http:// 或 https:// 开头',
)
self.api_base = api_base
self.api_key = cfg.get('api-key', '')
self.auth_header = cfg.get('auth-header', '')
self.assistant_id = cfg.get('assistant-id', 'lead_agent')
self.model_name = cfg.get('model-name', '')
self.thinking_enabled = bool(cfg.get('thinking-enabled', False))
self.plan_mode = bool(cfg.get('plan-mode', False))
self.subagent_enabled = bool(cfg.get('subagent-enabled', False))
self.max_concurrent_subagents = int(cfg.get('max-concurrent-subagents', 3))
self.timeout = int(cfg.get('timeout', 300))
self.recursion_limit = int(cfg.get('recursion-limit', 1000))
self.deerflow_client = client.AsyncDeerFlowClient(
api_base=self.api_base,
api_key=self.api_key,
auth_header=self.auth_header,
)
# ------------------------------------------------------------------
# 辅助方法
# ------------------------------------------------------------------
def _fingerprint_message(self, message: dict[str, typing.Any]) -> str:
try:
raw = json.dumps(message, sort_keys=True, ensure_ascii=False, default=str)
except (TypeError, ValueError):
raw = repr(message)
return hashlib.sha1(raw.encode('utf-8', errors='ignore')).hexdigest()
def _remember_seen_message_id(self, state: _StreamState, msg_id: str) -> None:
if not msg_id or msg_id in state.seen_message_ids:
return
state.seen_message_ids.add(msg_id)
state.seen_message_order.append(msg_id)
while len(state.seen_message_order) > _MAX_VALUES_HISTORY:
dropped = state.seen_message_order.popleft()
state.seen_message_ids.discard(dropped)
def _extract_new_messages_from_values(
self,
values_messages: list[typing.Any],
state: _StreamState,
) -> list[dict[str, typing.Any]]:
new_messages: list[dict[str, typing.Any]] = []
no_id_indexes_seen: set[int] = set()
for idx, msg in enumerate(values_messages):
if not isinstance(msg, dict):
continue
msg_id = stream_utils.get_message_id(msg)
if msg_id:
if msg_id in state.seen_message_ids:
continue
self._remember_seen_message_id(state, msg_id)
new_messages.append(msg)
continue
no_id_indexes_seen.add(idx)
fp = self._fingerprint_message(msg)
if state.no_id_message_fingerprints.get(idx) == fp:
continue
state.no_id_message_fingerprints[idx] = fp
new_messages.append(msg)
for idx in list(state.no_id_message_fingerprints.keys()):
if idx not in no_id_indexes_seen:
state.no_id_message_fingerprints.pop(idx, None)
return new_messages
# ------------------------------------------------------------------
# 用户输入处理
# ------------------------------------------------------------------
def _build_user_content(
self,
prompt: str,
image_urls: list[str],
) -> typing.Any:
"""构建 LangGraph 兼容的 user content支持多模态"""
if not image_urls:
return prompt
content: list[dict[str, typing.Any]] = []
if prompt:
content.append({'type': 'text', 'text': prompt})
for url in image_urls:
if not isinstance(url, str):
continue
url = url.strip()
if not url:
continue
if url.startswith(('http://', 'https://', 'data:')):
content.append({'type': 'image_url', 'image_url': {'url': url}})
return content if content else prompt
def _preprocess_user_message(
self,
query: pipeline_query.Query,
) -> tuple[str, list[str]]:
"""提取用户消息的纯文本与图片 URL 列表"""
plain_text = ''
image_urls: list[str] = []
if isinstance(query.user_message.content, str):
plain_text = query.user_message.content
elif isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == 'text':
plain_text += ce.text
elif ce.type == 'image_base64':
# 转换为 data URI 形式
b64 = getattr(ce, 'image_base64', '')
if b64:
if not b64.startswith('data:'):
b64 = f'data:image/png;base64,{b64}'
image_urls.append(b64)
elif ce.type == 'image_url':
url = getattr(ce, 'image_url', '')
if url:
image_urls.append(url)
return plain_text, image_urls
# ------------------------------------------------------------------
# 请求构造
# ------------------------------------------------------------------
def _build_messages(
self,
prompt: str,
image_urls: list[str],
system_prompt: str = '',
) -> list[dict[str, typing.Any]]:
messages: list[dict[str, typing.Any]] = []
if system_prompt:
messages.append({'role': 'system', 'content': system_prompt})
messages.append(
{
'role': 'user',
'content': self._build_user_content(prompt, image_urls),
}
)
return messages
def _build_runtime_configurable(self, thread_id: str) -> dict[str, typing.Any]:
cfg: dict[str, typing.Any] = {
'thread_id': thread_id,
'thinking_enabled': self.thinking_enabled,
'is_plan_mode': self.plan_mode,
'subagent_enabled': self.subagent_enabled,
}
if self.subagent_enabled:
cfg['max_concurrent_subagents'] = self.max_concurrent_subagents
if self.model_name:
cfg['model_name'] = self.model_name
return cfg
def _build_payload(
self,
thread_id: str,
prompt: str,
image_urls: list[str],
system_prompt: str = '',
) -> dict[str, typing.Any]:
runtime_configurable = self._build_runtime_configurable(thread_id)
return {
'assistant_id': self.assistant_id,
'input': {
'messages': self._build_messages(prompt, image_urls, system_prompt),
},
'stream_mode': ['values', 'messages-tuple', 'custom'],
# DeerFlow 2.0 从 config.configurable 读取运行时覆盖
# 同时保留 context 字段做向后兼容
'context': dict(runtime_configurable),
'config': {
'recursion_limit': self.recursion_limit,
'configurable': runtime_configurable,
},
}
# ------------------------------------------------------------------
# Session/Thread 管理
# ------------------------------------------------------------------
async def _ensure_thread_id(self, query: pipeline_query.Query) -> str:
"""从 query.session 取/创建 deerflow thread_id
LangBot 使用 `query.session.using_conversation.uuid` 持久化 conversation id
我们复用这个字段存储 deerflow thread_id与 Dify Runner 同样做法)。
"""
thread_id = query.session.using_conversation.uuid or ''
if thread_id:
return thread_id
thread = await self.deerflow_client.create_thread(timeout=min(30, self.timeout))
thread_id = thread.get('thread_id', '')
if not thread_id:
raise errors.DeerFlowAPIError(message=f'DeerFlow create thread 返回数据缺少 thread_id: {thread}')
query.session.using_conversation.uuid = thread_id
return thread_id
# ------------------------------------------------------------------
# 流式事件处理
# ------------------------------------------------------------------
def _handle_values_event(
self,
data: typing.Any,
state: _StreamState,
) -> str | None:
"""处理 values 事件,返回新的完整文本(增量基础上的全量)"""
values_messages = stream_utils.extract_messages_from_values_data(data)
if not values_messages:
return None
new_messages: list[dict[str, typing.Any]] = []
if not state.baseline_initialized:
state.baseline_initialized = True
for idx, msg in enumerate(values_messages):
if not isinstance(msg, dict):
continue
new_messages.append(msg)
msg_id = stream_utils.get_message_id(msg)
if msg_id:
self._remember_seen_message_id(state, msg_id)
continue
state.no_id_message_fingerprints[idx] = self._fingerprint_message(msg)
else:
new_messages = self._extract_new_messages_from_values(values_messages, state)
latest_text = ''
if new_messages:
state.run_values_messages.extend(new_messages)
if len(state.run_values_messages) > _MAX_VALUES_HISTORY:
state.run_values_messages = state.run_values_messages[-_MAX_VALUES_HISTORY:]
latest_text = stream_utils.extract_latest_ai_text(state.run_values_messages)
if latest_text:
state.has_values_text = True
latest_clarification = stream_utils.extract_latest_clarification_text(
state.run_values_messages,
)
if latest_clarification:
state.clarification_text = latest_clarification
return latest_text or None
def _handle_message_event(
self,
data: typing.Any,
state: _StreamState,
) -> str | None:
"""处理 messages-tuple 事件,返回增量文本
当 values 事件已经提供完整文本时,跳过 messages-tuple 的增量
"""
delta = stream_utils.extract_ai_delta_from_event_data(data)
if delta and not state.has_values_text:
state.latest_text += delta
return delta
maybe_clar = stream_utils.extract_clarification_from_event_data(data)
if maybe_clar:
state.clarification_text = maybe_clar
return None
def _build_final_text(self, state: _StreamState) -> str:
"""构建最终输出文本"""
if state.clarification_text:
return state.clarification_text
# 优先使用最后一条 AI message 的文本
latest_ai = stream_utils.extract_latest_ai_message(state.run_values_messages)
if latest_ai:
text = stream_utils.extract_text(latest_ai.get('content'))
if text:
if state.timed_out:
text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
return text
if state.latest_text:
text = state.latest_text
if state.timed_out:
text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
return text
# 提取任务失败信息作兜底
failure_text = stream_utils.build_task_failure_summary(state.task_failures)
if failure_text:
return failure_text
return 'DeerFlow 返回空响应'
# ------------------------------------------------------------------
# 主流程
# ------------------------------------------------------------------
async def _stream_messages_chunk(
self,
query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""流式输出生成器"""
plain_text, image_urls = self._preprocess_user_message(query)
system_prompt = ''
# LangBot 的 pipeline 通常通过 prompt-preprocess 已注入 system prompt
# 这里保持空,让 prompt-preprocess 的内容作为 user message 一并送给 deerflow
thread_id = await self._ensure_thread_id(query)
payload = self._build_payload(
thread_id=thread_id,
prompt=plain_text or 'continue',
image_urls=image_urls,
system_prompt=system_prompt,
)
state = _StreamState()
prev_text = ''
message_idx = 0
try:
async for event in self.deerflow_client.stream_run(
thread_id=thread_id,
payload=payload,
timeout=self.timeout,
):
event_type = event.get('event')
data = event.get('data')
if event_type == 'values':
new_full = self._handle_values_event(data, state)
if new_full and new_full != prev_text:
delta = new_full[len(prev_text) :] if new_full.startswith(prev_text) else new_full
prev_text = new_full
if delta:
message_idx += 1
yield provider_message.MessageChunk(
role='assistant',
content=new_full,
is_final=False,
)
continue
if event_type in {'messages-tuple', 'messages', 'message'}:
delta = self._handle_message_event(data, state)
if delta:
prev_text = state.latest_text
message_idx += 1
yield provider_message.MessageChunk(
role='assistant',
content=prev_text,
is_final=False,
)
continue
if event_type == 'custom':
state.task_failures.extend(
stream_utils.extract_task_failures_from_custom_event(data),
)
continue
if event_type == 'error':
raise errors.DeerFlowAPIError(message=f'DeerFlow stream error event: {data}')
if event_type == 'end':
break
except (asyncio.TimeoutError, TimeoutError):
self.ap.logger.warning(f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}')
state.timed_out = True
# 最终消息
final_text = self._build_final_text(state)
yield provider_message.MessageChunk(
role='assistant',
content=final_text,
is_final=True,
)
async def _messages(
self,
query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""非流式聚合输出"""
plain_text, image_urls = self._preprocess_user_message(query)
thread_id = await self._ensure_thread_id(query)
payload = self._build_payload(
thread_id=thread_id,
prompt=plain_text or 'continue',
image_urls=image_urls,
)
state = _StreamState()
try:
async for event in self.deerflow_client.stream_run(
thread_id=thread_id,
payload=payload,
timeout=self.timeout,
):
event_type = event.get('event')
data = event.get('data')
if event_type == 'values':
self._handle_values_event(data, state)
continue
if event_type in {'messages-tuple', 'messages', 'message'}:
self._handle_message_event(data, state)
continue
if event_type == 'custom':
state.task_failures.extend(
stream_utils.extract_task_failures_from_custom_event(data),
)
continue
if event_type == 'error':
raise errors.DeerFlowAPIError(message=f'DeerFlow stream error event: {data}')
if event_type == 'end':
break
except (asyncio.TimeoutError, TimeoutError):
self.ap.logger.warning(f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}')
state.timed_out = True
final_text = self._build_final_text(state)
yield provider_message.Message(
role='assistant',
content=final_text,
)
async def run(
self,
query: pipeline_query.Query,
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""主入口:根据 adapter 是否支持流式输出,选择流式或非流式"""
if await query.adapter.is_stream_output_supported():
msg_idx = 0
async for msg in self._stream_messages_chunk(query):
msg_idx += 1
msg.msg_sequence = msg_idx
yield msg
else:
async for msg in self._messages(query):
yield msg

View File

@@ -0,0 +1,351 @@
from __future__ import annotations
import typing
import json
from langbot.pkg.provider import runner
from langbot.pkg.core import app
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from langbot.libs.weknora_api import client, errors
@runner.runner_class('weknora-api')
class WeKnoraAPIRunner(runner.RequestRunner):
"""WeKnora API 对话请求器"""
weknora_client: client.AsyncWeKnoraClient
def __init__(self, ap: app.Application, pipeline_config: dict):
super().__init__(ap, pipeline_config)
valid_app_types = ['chat', 'agent']
if self.pipeline_config['ai']['weknora-api']['app-type'] not in valid_app_types:
raise errors.WeKnoraAPIError(
f'不支持的 WeKnora 应用类型: {self.pipeline_config["ai"]["weknora-api"]["app-type"]}'
)
api_key = self.pipeline_config['ai']['weknora-api'].get('api-key', '').strip()
if not api_key:
raise errors.WeKnoraAPIError(
'WeKnora API Key 未配置,请在流水线的 WeKnora API 配置中填入 API Key '
'(从 WeKnora 前端 设置 → API Keys 生成)'
)
base_url = self.pipeline_config['ai']['weknora-api'].get('base-url', '').strip()
if not base_url:
raise errors.WeKnoraAPIError('WeKnora Base URL 未配置,请填入服务器地址,例如 http://localhost:8080/api/v1')
self.weknora_client = client.AsyncWeKnoraClient(
api_key=api_key,
base_url=base_url,
)
async def _extract_plain_text(self, query: pipeline_query.Query) -> str:
"""从用户消息中提取纯文本内容"""
plain_text = ''
if isinstance(query.user_message.content, str):
plain_text = query.user_message.content
elif isinstance(query.user_message.content, list):
for ce in query.user_message.content:
if ce.type == 'text':
plain_text += ce.text
if not plain_text:
plain_text = self.pipeline_config['ai']['weknora-api'].get('base-prompt', '')
return plain_text
async def _ensure_session(self, query: pipeline_query.Query) -> str:
"""确保会话存在,如果不存在则创建"""
session_id = query.session.using_conversation.uuid or ''
if not session_id:
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
session_id = await self.weknora_client.create_session(title=f'IM Chat - {user_tag}')
query.session.using_conversation.uuid = session_id
return session_id
async def _agent_chat_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用 Agent 智能对话(非流式聚合输出)"""
session_id = await self._ensure_session(query)
plain_text = await self._extract_plain_text(query)
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
config = self.pipeline_config['ai']['weknora-api']
agent_id = config.get('agent-id', 'builtin-smart-reasoning')
knowledge_base_ids = config.get('knowledge-base-ids', [])
web_search_enabled = config.get('web-search-enabled', False)
timeout = config.get('timeout', 120)
full_answer = ''
chunk = None
async for chunk in self.weknora_client.agent_chat(
session_id=session_id,
query=plain_text,
user=user_tag,
agent_id=agent_id,
knowledge_base_ids=knowledge_base_ids,
web_search_enabled=web_search_enabled,
timeout=timeout,
):
self.ap.logger.debug('weknora-agent-chunk: ' + str(chunk))
response_type = chunk.get('response_type', '')
content = chunk.get('content', '')
if response_type == 'tool_call':
# 工具调用
tool_data = chunk.get('data', {})
tool_name = tool_data.get('tool_name', '')
if tool_name:
yield provider_message.Message(
role='assistant',
tool_calls=[
provider_message.ToolCall(
id=chunk.get('id', ''),
type='function',
function=provider_message.FunctionCall(
name=tool_name,
arguments=json.dumps(tool_data.get('arguments', {})),
),
)
],
)
elif response_type == 'answer':
if content:
full_answer += content
elif response_type == 'error':
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
if chunk is None:
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应请检查网络连接和API配置')
if full_answer:
yield provider_message.Message(
role='assistant',
content=full_answer,
)
async def _chat_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用知识库 RAG 问答(非流式聚合输出)"""
session_id = await self._ensure_session(query)
plain_text = await self._extract_plain_text(query)
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
config = self.pipeline_config['ai']['weknora-api']
agent_id = config.get('agent-id', 'builtin-quick-answer')
knowledge_base_ids = config.get('knowledge-base-ids', [])
timeout = config.get('timeout', 120)
full_answer = ''
chunk = None
async for chunk in self.weknora_client.knowledge_chat(
session_id=session_id,
query=plain_text,
user=user_tag,
agent_id=agent_id,
knowledge_base_ids=knowledge_base_ids,
timeout=timeout,
):
self.ap.logger.debug('weknora-chat-chunk: ' + str(chunk))
response_type = chunk.get('response_type', '')
content = chunk.get('content', '')
if response_type == 'answer':
if content:
full_answer += content
elif response_type == 'error':
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
if chunk is None:
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应请检查网络连接和API配置')
if full_answer:
yield provider_message.Message(
role='assistant',
content=full_answer,
)
async def _agent_chat_messages_chunk(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""调用 Agent 智能对话(流式输出)"""
session_id = await self._ensure_session(query)
plain_text = await self._extract_plain_text(query)
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
config = self.pipeline_config['ai']['weknora-api']
agent_id = config.get('agent-id', 'builtin-smart-reasoning')
knowledge_base_ids = config.get('knowledge-base-ids', [])
web_search_enabled = config.get('web-search-enabled', False)
timeout = config.get('timeout', 120)
pending_answer = ''
message_idx = 0
is_final = False
chunk = None
async for chunk in self.weknora_client.agent_chat(
session_id=session_id,
query=plain_text,
user=user_tag,
agent_id=agent_id,
knowledge_base_ids=knowledge_base_ids,
web_search_enabled=web_search_enabled,
timeout=timeout,
):
self.ap.logger.debug('weknora-agent-chunk: ' + str(chunk))
response_type = chunk.get('response_type', '')
content = chunk.get('content', '')
done = chunk.get('done', False)
if response_type == 'tool_call':
tool_data = chunk.get('data', {})
tool_name = tool_data.get('tool_name', '')
if tool_name:
message_idx += 1
yield provider_message.MessageChunk(
role='assistant',
tool_calls=[
provider_message.ToolCall(
id=chunk.get('id', ''),
type='function',
function=provider_message.FunctionCall(
name=tool_name,
arguments=json.dumps(tool_data.get('arguments', {})),
),
)
],
)
elif response_type == 'answer':
message_idx += 1
if content:
pending_answer += content
if done:
is_final = True
# 每 8 个 chunk 输出一次,或最终输出
if message_idx % 8 == 0 or is_final:
yield provider_message.MessageChunk(
role='assistant',
content=pending_answer,
is_final=is_final,
)
elif response_type == 'error':
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
if chunk is None:
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应请检查网络连接和API配置')
# 确保最终消息已发出
if not is_final and pending_answer:
yield provider_message.MessageChunk(
role='assistant',
content=pending_answer,
is_final=True,
)
async def _chat_messages_chunk(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
"""调用知识库 RAG 问答(流式输出)"""
session_id = await self._ensure_session(query)
plain_text = await self._extract_plain_text(query)
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
config = self.pipeline_config['ai']['weknora-api']
agent_id = config.get('agent-id', 'builtin-quick-answer')
knowledge_base_ids = config.get('knowledge-base-ids', [])
timeout = config.get('timeout', 120)
pending_answer = ''
message_idx = 0
is_final = False
chunk = None
async for chunk in self.weknora_client.knowledge_chat(
session_id=session_id,
query=plain_text,
user=user_tag,
agent_id=agent_id,
knowledge_base_ids=knowledge_base_ids,
timeout=timeout,
):
self.ap.logger.debug('weknora-chat-chunk: ' + str(chunk))
response_type = chunk.get('response_type', '')
content = chunk.get('content', '')
done = chunk.get('done', False)
if response_type == 'answer':
message_idx += 1
if content:
pending_answer += content
if done:
is_final = True
if message_idx % 8 == 0 or is_final:
yield provider_message.MessageChunk(
role='assistant',
content=pending_answer,
is_final=is_final,
)
elif response_type == 'error':
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
if chunk is None:
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应请检查网络连接和API配置')
if not is_final and pending_answer:
yield provider_message.MessageChunk(
role='assistant',
content=pending_answer,
is_final=True,
)
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求"""
app_type = self.pipeline_config['ai']['weknora-api']['app-type']
if await query.adapter.is_stream_output_supported():
msg_idx = 0
if app_type == 'agent':
async for msg in self._agent_chat_messages_chunk(query):
msg_idx += 1
msg.msg_sequence = msg_idx
yield msg
elif app_type == 'chat':
async for msg in self._chat_messages_chunk(query):
msg_idx += 1
msg.msg_sequence = msg_idx
yield msg
else:
raise errors.WeKnoraAPIError(f'不支持的 WeKnora 应用类型: {app_type}')
else:
if app_type == 'agent':
async for msg in self._agent_chat_messages(query):
yield msg
elif app_type == 'chat':
async for msg in self._chat_messages(query):
yield msg
else:
raise errors.WeKnoraAPIError(f'不支持的 WeKnora 应用类型: {app_type}')