mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-13 17:26:04 +00:00
fix: harden agent runner runtime boundaries
This commit is contained in:
@@ -16,6 +16,7 @@ from ...entity.persistence.artifact import AgentArtifact
|
||||
from ...entity.persistence.bstorage import BinaryStorage
|
||||
|
||||
_FILE_ARTIFACT_METADATA_KEY = '_langbot_file_artifact'
|
||||
_ARTIFACT_THREAD_METADATA_KEY = '_langbot_thread_id'
|
||||
|
||||
|
||||
class ArtifactStore:
|
||||
@@ -56,6 +57,7 @@ class ArtifactStore:
|
||||
runner_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
expires_at: datetime.datetime | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
) -> str:
|
||||
@@ -92,6 +94,7 @@ class ArtifactStore:
|
||||
runner_id=runner_id,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
thread_id=thread_id,
|
||||
expires_at=expires_at,
|
||||
metadata=public_metadata,
|
||||
content=None,
|
||||
@@ -113,6 +116,7 @@ class ArtifactStore:
|
||||
runner_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
expires_at: datetime.datetime | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
content: bytes | None = None,
|
||||
@@ -137,6 +141,7 @@ class ArtifactStore:
|
||||
runner_id: Runner ID that created this
|
||||
bot_id: Bot UUID
|
||||
workspace_id: Workspace ID
|
||||
thread_id: Thread ID stored as Host-only metadata
|
||||
expires_at: Expiration time
|
||||
metadata: Additional metadata
|
||||
content: Optional content to store in BinaryStorage
|
||||
@@ -147,6 +152,10 @@ class ArtifactStore:
|
||||
if artifact_id is None:
|
||||
artifact_id = str(uuid.uuid4())
|
||||
|
||||
metadata_payload = dict(metadata or {})
|
||||
if thread_id is not None:
|
||||
metadata_payload[_ARTIFACT_THREAD_METADATA_KEY] = thread_id
|
||||
|
||||
# If content provided, store in BinaryStorage
|
||||
if content is not None and storage_key is None:
|
||||
storage_key = f"artifact:{artifact_id}"
|
||||
@@ -184,7 +193,7 @@ class ArtifactStore:
|
||||
workspace_id=workspace_id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
expires_at=expires_at,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
metadata_json=json.dumps(metadata_payload) if metadata_payload else None,
|
||||
)
|
||||
session.add(artifact)
|
||||
await session.commit()
|
||||
@@ -216,6 +225,22 @@ class ArtifactStore:
|
||||
return None
|
||||
return self._row_to_public_dict(row)
|
||||
|
||||
async def get_authorization_metadata(
|
||||
self,
|
||||
artifact_id: str,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Get artifact metadata with Host-only scope fields for authorization."""
|
||||
row = await self._get_internal_record(artifact_id)
|
||||
if row is None:
|
||||
return None
|
||||
metadata = self._row_to_public_dict(row)
|
||||
metadata.update({
|
||||
'bot_id': row.bot_id,
|
||||
'workspace_id': row.workspace_id,
|
||||
'thread_id': self._load_metadata(row.metadata_json).get(_ARTIFACT_THREAD_METADATA_KEY),
|
||||
})
|
||||
return metadata
|
||||
|
||||
async def _get_internal_record(
|
||||
self,
|
||||
artifact_id: str,
|
||||
@@ -455,6 +480,7 @@ class ArtifactStore:
|
||||
def _public_metadata(metadata_json: str | None) -> dict[str, typing.Any]:
|
||||
metadata = ArtifactStore._load_metadata(metadata_json)
|
||||
metadata.pop(_FILE_ARTIFACT_METADATA_KEY, None)
|
||||
metadata.pop(_ARTIFACT_THREAD_METADATA_KEY, None)
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -6,6 +6,10 @@ import typing
|
||||
from .descriptor import AgentRunnerDescriptor
|
||||
|
||||
|
||||
FORM_ITEM_TYPE_ALIASES = {
|
||||
'select-llm-model': 'llm-model-selector',
|
||||
'select-knowledge-bases': 'knowledge-base-multi-selector',
|
||||
}
|
||||
LLM_MODEL_SELECTOR_TYPES = {'model-fallback-selector', 'llm-model-selector'}
|
||||
KB_SELECTOR_TYPES = {'knowledge-base-multi-selector'}
|
||||
PROMPT_EDITOR_TYPES = {'prompt-editor'}
|
||||
@@ -13,6 +17,13 @@ FILE_SELECTOR_TYPES = {'file', 'array[file]'}
|
||||
NONE_SENTINELS = {'', '__none__', '__none'}
|
||||
|
||||
|
||||
def normalize_schema_item_type(item_type: typing.Any) -> typing.Any:
|
||||
"""Normalize legacy/frontend DynamicForm aliases to protocol field types."""
|
||||
if not isinstance(item_type, str):
|
||||
return item_type
|
||||
return FORM_ITEM_TYPE_ALIASES.get(item_type, item_type)
|
||||
|
||||
|
||||
def iter_schema_items(
|
||||
descriptor: AgentRunnerDescriptor | None,
|
||||
field_types: set[str],
|
||||
@@ -23,7 +34,7 @@ def iter_schema_items(
|
||||
for item in descriptor.config_schema or []:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get('type') in field_types:
|
||||
if normalize_schema_item_type(item.get('type')) in field_types:
|
||||
yield item
|
||||
|
||||
|
||||
@@ -81,7 +92,8 @@ def extract_model_selection(
|
||||
continue
|
||||
|
||||
value = runner_config.get(field_name, item.get('default'))
|
||||
if item.get('type') == 'model-fallback-selector':
|
||||
item_type = normalize_schema_item_type(item.get('type'))
|
||||
if item_type == 'model-fallback-selector':
|
||||
if isinstance(value, str):
|
||||
primary_uuid = value
|
||||
elif isinstance(value, dict):
|
||||
@@ -91,7 +103,7 @@ def extract_model_selection(
|
||||
fallback_uuids = [fallback for fallback in fallbacks if isinstance(fallback, str)]
|
||||
break
|
||||
|
||||
if item.get('type') == 'llm-model-selector' and isinstance(value, str):
|
||||
if item_type == 'llm-model-selector' and isinstance(value, str):
|
||||
primary_uuid = value
|
||||
break
|
||||
|
||||
@@ -145,7 +157,8 @@ def extract_config_file_resources(
|
||||
if not field_name:
|
||||
continue
|
||||
value = runner_config.get(field_name, item.get('default'))
|
||||
if item.get('type') == 'file':
|
||||
item_type = normalize_schema_item_type(item.get('type'))
|
||||
if item_type == 'file':
|
||||
append_file(value)
|
||||
elif isinstance(value, list):
|
||||
for entry in value:
|
||||
@@ -167,7 +180,7 @@ def iter_config_model_refs(
|
||||
continue
|
||||
|
||||
field_name = item.get('name')
|
||||
field_type = item.get('type')
|
||||
field_type = normalize_schema_item_type(item.get('type'))
|
||||
if not field_name or field_name not in runner_config:
|
||||
continue
|
||||
|
||||
@@ -200,7 +213,7 @@ def set_empty_llm_model_selection(
|
||||
"""Set the first empty schema-defined LLM selector to model_uuid."""
|
||||
for item in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES):
|
||||
field_name = item.get('name')
|
||||
field_type = item.get('type')
|
||||
field_type = normalize_schema_item_type(item.get('type'))
|
||||
if not field_name:
|
||||
continue
|
||||
|
||||
|
||||
@@ -283,6 +283,9 @@ class AgentRunOrchestrator:
|
||||
target_run_id = await self._session_registry.find_steering_target(
|
||||
conversation_id=event.conversation_id,
|
||||
runner_id=descriptor.id,
|
||||
bot_id=event.bot_id,
|
||||
workspace_id=event.workspace_id,
|
||||
thread_id=event.thread_id,
|
||||
)
|
||||
if target_run_id is None:
|
||||
return False
|
||||
|
||||
@@ -230,6 +230,7 @@ class AgentRunJournal:
|
||||
runner_id=runner_id,
|
||||
bot_id=event.bot_id,
|
||||
workspace_id=event.workspace_id,
|
||||
thread_id=event.thread_id,
|
||||
metadata=metadata,
|
||||
content=content,
|
||||
)
|
||||
@@ -371,6 +372,7 @@ class AgentRunJournal:
|
||||
runner_id=runner_id,
|
||||
bot_id=event.bot_id,
|
||||
workspace_id=event.workspace_id,
|
||||
thread_id=event.thread_id,
|
||||
metadata=metadata,
|
||||
content=content,
|
||||
)
|
||||
|
||||
@@ -254,6 +254,9 @@ class AgentRunSessionRegistry:
|
||||
*,
|
||||
conversation_id: str,
|
||||
runner_id: str,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""Find the oldest active run that can accept steering for a conversation."""
|
||||
async with self._lock:
|
||||
@@ -264,6 +267,12 @@ class AgentRunSessionRegistry:
|
||||
continue
|
||||
if authorization.get('conversation_id') != conversation_id:
|
||||
continue
|
||||
if authorization.get('bot_id') != bot_id:
|
||||
continue
|
||||
if authorization.get('workspace_id') != workspace_id:
|
||||
continue
|
||||
if authorization.get('thread_id') != thread_id:
|
||||
continue
|
||||
if not authorization.get('available_apis', {}).get('steering_pull', False):
|
||||
continue
|
||||
candidates.append((session['status'].get('started_at', 0), run_id))
|
||||
|
||||
@@ -254,6 +254,10 @@ class TranscriptStore:
|
||||
self,
|
||||
conversation_id: str,
|
||||
limit: int = HARD_LIMIT,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
strict_thread: bool = False,
|
||||
) -> list[provider_message.Message]:
|
||||
"""Project Transcript rows into the legacy provider Message view.
|
||||
|
||||
@@ -265,6 +269,10 @@ class TranscriptStore:
|
||||
conversation_id=conversation_id,
|
||||
limit=limit,
|
||||
direction="backward",
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
thread_id=thread_id,
|
||||
strict_thread=strict_thread,
|
||||
)
|
||||
|
||||
messages: list[provider_message.Message] = []
|
||||
|
||||
@@ -236,12 +236,18 @@ class BoxService:
|
||||
if forced_template:
|
||||
template = forced_template
|
||||
else:
|
||||
template = (
|
||||
(query.pipeline_config or {})
|
||||
.get('ai', {})
|
||||
.get('local-agent', {})
|
||||
.get('box-session-id-template', '{launcher_type}_{launcher_id}')
|
||||
template = '{launcher_type}_{launcher_id}'
|
||||
pipeline_config = query.pipeline_config or {}
|
||||
ai_config = pipeline_config.get('ai', {}) if isinstance(pipeline_config, dict) else {}
|
||||
runner_selector = ai_config.get('runner', {}) if isinstance(ai_config, dict) else {}
|
||||
runner_id = runner_selector.get('id') if isinstance(runner_selector, dict) else None
|
||||
runner_configs = ai_config.get('runner_config', {}) if isinstance(ai_config, dict) else {}
|
||||
runner_config = runner_configs.get(runner_id, {}) if isinstance(runner_configs, dict) else {}
|
||||
configured_template = (
|
||||
runner_config.get('box-session-id-template') if isinstance(runner_config, dict) else None
|
||||
)
|
||||
if isinstance(configured_template, str) and configured_template:
|
||||
template = configured_template
|
||||
variables = dict(query.variables or {})
|
||||
launcher_type = getattr(query, 'launcher_type', None)
|
||||
if hasattr(launcher_type, 'value'):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Normalize AgentRunner config containers
|
||||
|
||||
Revision ID: 0004_migrate_runner_config
|
||||
Revises: 0003_add_rerank_models
|
||||
Revision ID: 0005_migrate_runner_config
|
||||
Revises: 0004_add_mcp_readme
|
||||
Create Date: 2026-05-10
|
||||
"""
|
||||
|
||||
@@ -11,8 +11,8 @@ from alembic import op
|
||||
|
||||
from langbot.pkg.agent.runner.config_migration import ConfigMigration
|
||||
|
||||
revision = '0004_migrate_runner_config'
|
||||
down_revision = '0003_add_rerank_models'
|
||||
revision = '0005_migrate_runner_config'
|
||||
down_revision = '0004_add_mcp_readme'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""add_event_log_and_transcript_tables
|
||||
|
||||
Revision ID: 58846a8d7a81
|
||||
Revises: 0004_migrate_runner_config
|
||||
Revises: 0005_migrate_runner_config
|
||||
Create Date: 2026-05-23 15:41:47.030841
|
||||
"""
|
||||
from alembic import op
|
||||
@@ -9,7 +9,7 @@ import sqlalchemy as sa
|
||||
|
||||
# revision identifiers
|
||||
revision = '58846a8d7a81'
|
||||
down_revision = '0004_migrate_runner_config'
|
||||
down_revision = '0005_migrate_runner_config'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
@@ -117,6 +117,9 @@ class PreProcessor(stage.PipelineStage):
|
||||
self,
|
||||
runner_id: str | None,
|
||||
conversation_uuid: str | None,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
) -> list[provider_message.Message] | None:
|
||||
if not runner_id or not conversation_uuid or not self._has_declared_db_engine():
|
||||
return None
|
||||
@@ -125,7 +128,13 @@ class PreProcessor(stage.PipelineStage):
|
||||
from ...agent.runner.transcript_store import TranscriptStore
|
||||
|
||||
store = TranscriptStore(self.ap.persistence_mgr.get_db_engine())
|
||||
messages = await store.get_legacy_provider_messages(str(conversation_uuid))
|
||||
messages = await store.get_legacy_provider_messages(
|
||||
str(conversation_uuid),
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
thread_id=thread_id,
|
||||
strict_thread=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(
|
||||
f'Unable to load Transcript history view for conversation {conversation_uuid}: {e}'
|
||||
@@ -138,10 +147,15 @@ class PreProcessor(stage.PipelineStage):
|
||||
self,
|
||||
runner_id: str | None,
|
||||
conversation: typing.Any,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
) -> list[provider_message.Message]:
|
||||
transcript_messages = await self._load_agent_runner_history_messages(
|
||||
runner_id,
|
||||
getattr(conversation, 'uuid', None),
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
thread_id=getattr(conversation, 'thread_id', None),
|
||||
)
|
||||
if transcript_messages is not None:
|
||||
return transcript_messages
|
||||
@@ -213,7 +227,11 @@ class PreProcessor(stage.PipelineStage):
|
||||
# Attach resolved session state to the query.
|
||||
query.session = session
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = await self._resolve_history_messages(runner_id, conversation)
|
||||
query.messages = await self._resolve_history_messages(
|
||||
runner_id,
|
||||
conversation,
|
||||
bot_id=query.bot_uuid,
|
||||
)
|
||||
|
||||
if uses_host_models:
|
||||
query.use_funcs = []
|
||||
|
||||
@@ -120,7 +120,8 @@ def _validate_artifact_access(
|
||||
|
||||
Args:
|
||||
session: AgentRunSession dict with run_id and authorization snapshot
|
||||
artifact_metadata: Artifact metadata dict with conversation_id, run_id
|
||||
artifact_metadata: Artifact metadata dict with conversation_id, run_id,
|
||||
and Host-only scope fields when available
|
||||
operation: Operation name for error messages ('metadata' or 'read')
|
||||
|
||||
Returns:
|
||||
@@ -138,7 +139,7 @@ def _validate_artifact_access(
|
||||
|
||||
# Rule 2: Same conversation (requires artifact to have conversation_id)
|
||||
if artifact_conversation_id and session_conversation_id:
|
||||
if artifact_conversation_id == session_conversation_id:
|
||||
if artifact_conversation_id == session_conversation_id and _artifact_matches_run_scope(session, artifact_metadata):
|
||||
return True, None
|
||||
|
||||
# Rule 3: Deny - no matching authorization rule
|
||||
@@ -150,6 +151,21 @@ def _get_run_authorization(session: dict[str, Any]) -> dict[str, Any]:
|
||||
return session['authorization']
|
||||
|
||||
|
||||
def _artifact_matches_run_scope(session: dict[str, Any], artifact_metadata: dict[str, Any]) -> bool:
|
||||
authorization = _get_run_authorization(session)
|
||||
for scope_key in ('bot_id', 'workspace_id', 'thread_id'):
|
||||
if authorization.get(scope_key) != artifact_metadata.get(scope_key):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _public_artifact_metadata(artifact_metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
public_metadata = dict(artifact_metadata)
|
||||
for scope_key in ('bot_id', 'workspace_id', 'thread_id'):
|
||||
public_metadata.pop(scope_key, None)
|
||||
return public_metadata
|
||||
|
||||
|
||||
def _resolve_state_scope(
|
||||
session: dict[str, Any],
|
||||
scope: str,
|
||||
@@ -1864,7 +1880,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
metadata = await store.get_metadata(artifact_id)
|
||||
metadata = await store.get_authorization_metadata(artifact_id)
|
||||
if not metadata:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Artifact {artifact_id} not found'
|
||||
@@ -1875,7 +1891,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
if not is_allowed:
|
||||
return handler.ActionResponse.error(message=error_msg)
|
||||
|
||||
return handler.ActionResponse.success(data=metadata)
|
||||
return handler.ActionResponse.success(data=_public_artifact_metadata(metadata))
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'ARTIFACT_METADATA error: {e}', exc_info=True)
|
||||
return handler.ActionResponse.error(message=f'Artifact metadata error: {e}')
|
||||
@@ -1934,7 +1950,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
store = ArtifactStore(self.ap.persistence_mgr.get_db_engine())
|
||||
|
||||
try:
|
||||
metadata = await store.get_metadata(artifact_id)
|
||||
metadata = await store.get_authorization_metadata(artifact_id)
|
||||
if not metadata:
|
||||
return handler.ActionResponse.error(
|
||||
message=f'Artifact {artifact_id} not found'
|
||||
|
||||
@@ -1,511 +0,0 @@
|
||||
"""DeerFlow LangGraph API Runner
|
||||
|
||||
参考 astrbot 的 deerflow_agent_runner 实现,适配 LangBot 的 Runner 接口。
|
||||
|
||||
特点:
|
||||
- 使用 LangGraph HTTP API 接入 deer-flow 后端
|
||||
- 自动管理 thread_id(按 session 隔离)
|
||||
- 支持 SSE 流式响应解析
|
||||
- 支持 streaming/非流式两种输出
|
||||
- 处理 values / messages-tuple / custom 三种事件
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import typing
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
from langbot.pkg.provider import runner
|
||||
from langbot.pkg.core import app
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from langbot.libs.deerflow_api import client, errors, stream_utils
|
||||
|
||||
|
||||
_MAX_VALUES_HISTORY = 200
|
||||
|
||||
|
||||
@dataclass
|
||||
class _StreamState:
|
||||
"""流式状态跟踪"""
|
||||
|
||||
latest_text: str = ''
|
||||
prev_text_for_streaming: str = ''
|
||||
clarification_text: str = ''
|
||||
task_failures: list[str] = field(default_factory=list)
|
||||
seen_message_ids: set[str] = field(default_factory=set)
|
||||
seen_message_order: deque[str] = field(default_factory=deque)
|
||||
no_id_message_fingerprints: dict[int, str] = field(default_factory=dict)
|
||||
baseline_initialized: bool = False
|
||||
has_values_text: bool = False
|
||||
run_values_messages: list[dict[str, typing.Any]] = field(default_factory=list)
|
||||
timed_out: bool = False
|
||||
|
||||
|
||||
@runner.runner_class('deerflow-api')
|
||||
class DeerFlowAPIRunner(runner.RequestRunner):
|
||||
"""DeerFlow LangGraph API 对话请求器"""
|
||||
|
||||
deerflow_client: client.AsyncDeerFlowClient
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
super().__init__(ap, pipeline_config)
|
||||
|
||||
cfg = self.pipeline_config['ai']['deerflow-api']
|
||||
|
||||
api_base = cfg.get('api-base', '').strip()
|
||||
if not api_base or not api_base.startswith(('http://', 'https://')):
|
||||
raise errors.DeerFlowAPIError(
|
||||
message='DeerFlow API Base URL 格式错误,必须以 http:// 或 https:// 开头',
|
||||
)
|
||||
|
||||
self.api_base = api_base
|
||||
self.api_key = cfg.get('api-key', '')
|
||||
self.auth_header = cfg.get('auth-header', '')
|
||||
self.assistant_id = cfg.get('assistant-id', 'lead_agent')
|
||||
self.model_name = cfg.get('model-name', '')
|
||||
self.thinking_enabled = bool(cfg.get('thinking-enabled', False))
|
||||
self.plan_mode = bool(cfg.get('plan-mode', False))
|
||||
self.subagent_enabled = bool(cfg.get('subagent-enabled', False))
|
||||
self.max_concurrent_subagents = int(cfg.get('max-concurrent-subagents', 3))
|
||||
self.timeout = int(cfg.get('timeout', 300))
|
||||
self.recursion_limit = int(cfg.get('recursion-limit', 1000))
|
||||
|
||||
self.deerflow_client = client.AsyncDeerFlowClient(
|
||||
api_base=self.api_base,
|
||||
api_key=self.api_key,
|
||||
auth_header=self.auth_header,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 辅助方法
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _fingerprint_message(self, message: dict[str, typing.Any]) -> str:
|
||||
try:
|
||||
raw = json.dumps(message, sort_keys=True, ensure_ascii=False, default=str)
|
||||
except (TypeError, ValueError):
|
||||
raw = repr(message)
|
||||
return hashlib.sha1(raw.encode('utf-8', errors='ignore')).hexdigest()
|
||||
|
||||
def _remember_seen_message_id(self, state: _StreamState, msg_id: str) -> None:
|
||||
if not msg_id or msg_id in state.seen_message_ids:
|
||||
return
|
||||
state.seen_message_ids.add(msg_id)
|
||||
state.seen_message_order.append(msg_id)
|
||||
while len(state.seen_message_order) > _MAX_VALUES_HISTORY:
|
||||
dropped = state.seen_message_order.popleft()
|
||||
state.seen_message_ids.discard(dropped)
|
||||
|
||||
def _extract_new_messages_from_values(
|
||||
self,
|
||||
values_messages: list[typing.Any],
|
||||
state: _StreamState,
|
||||
) -> list[dict[str, typing.Any]]:
|
||||
new_messages: list[dict[str, typing.Any]] = []
|
||||
no_id_indexes_seen: set[int] = set()
|
||||
for idx, msg in enumerate(values_messages):
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
msg_id = stream_utils.get_message_id(msg)
|
||||
if msg_id:
|
||||
if msg_id in state.seen_message_ids:
|
||||
continue
|
||||
self._remember_seen_message_id(state, msg_id)
|
||||
new_messages.append(msg)
|
||||
continue
|
||||
|
||||
no_id_indexes_seen.add(idx)
|
||||
fp = self._fingerprint_message(msg)
|
||||
if state.no_id_message_fingerprints.get(idx) == fp:
|
||||
continue
|
||||
state.no_id_message_fingerprints[idx] = fp
|
||||
new_messages.append(msg)
|
||||
|
||||
for idx in list(state.no_id_message_fingerprints.keys()):
|
||||
if idx not in no_id_indexes_seen:
|
||||
state.no_id_message_fingerprints.pop(idx, None)
|
||||
return new_messages
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 用户输入处理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_user_content(
|
||||
self,
|
||||
prompt: str,
|
||||
image_urls: list[str],
|
||||
) -> typing.Any:
|
||||
"""构建 LangGraph 兼容的 user content(支持多模态)"""
|
||||
if not image_urls:
|
||||
return prompt
|
||||
|
||||
content: list[dict[str, typing.Any]] = []
|
||||
if prompt:
|
||||
content.append({'type': 'text', 'text': prompt})
|
||||
for url in image_urls:
|
||||
if not isinstance(url, str):
|
||||
continue
|
||||
url = url.strip()
|
||||
if not url:
|
||||
continue
|
||||
if url.startswith(('http://', 'https://', 'data:')):
|
||||
content.append({'type': 'image_url', 'image_url': {'url': url}})
|
||||
return content if content else prompt
|
||||
|
||||
def _preprocess_user_message(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
) -> tuple[str, list[str]]:
|
||||
"""提取用户消息的纯文本与图片 URL 列表"""
|
||||
plain_text = ''
|
||||
image_urls: list[str] = []
|
||||
|
||||
if isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
elif isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
elif ce.type == 'image_base64':
|
||||
# 转换为 data URI 形式
|
||||
b64 = getattr(ce, 'image_base64', '')
|
||||
if b64:
|
||||
if not b64.startswith('data:'):
|
||||
b64 = f'data:image/png;base64,{b64}'
|
||||
image_urls.append(b64)
|
||||
elif ce.type == 'image_url':
|
||||
url = getattr(ce, 'image_url', '')
|
||||
if url:
|
||||
image_urls.append(url)
|
||||
|
||||
return plain_text, image_urls
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 请求构造
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
prompt: str,
|
||||
image_urls: list[str],
|
||||
system_prompt: str = '',
|
||||
) -> list[dict[str, typing.Any]]:
|
||||
messages: list[dict[str, typing.Any]] = []
|
||||
if system_prompt:
|
||||
messages.append({'role': 'system', 'content': system_prompt})
|
||||
messages.append(
|
||||
{
|
||||
'role': 'user',
|
||||
'content': self._build_user_content(prompt, image_urls),
|
||||
}
|
||||
)
|
||||
return messages
|
||||
|
||||
def _build_runtime_configurable(self, thread_id: str) -> dict[str, typing.Any]:
|
||||
cfg: dict[str, typing.Any] = {
|
||||
'thread_id': thread_id,
|
||||
'thinking_enabled': self.thinking_enabled,
|
||||
'is_plan_mode': self.plan_mode,
|
||||
'subagent_enabled': self.subagent_enabled,
|
||||
}
|
||||
if self.subagent_enabled:
|
||||
cfg['max_concurrent_subagents'] = self.max_concurrent_subagents
|
||||
if self.model_name:
|
||||
cfg['model_name'] = self.model_name
|
||||
return cfg
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
thread_id: str,
|
||||
prompt: str,
|
||||
image_urls: list[str],
|
||||
system_prompt: str = '',
|
||||
) -> dict[str, typing.Any]:
|
||||
runtime_configurable = self._build_runtime_configurable(thread_id)
|
||||
return {
|
||||
'assistant_id': self.assistant_id,
|
||||
'input': {
|
||||
'messages': self._build_messages(prompt, image_urls, system_prompt),
|
||||
},
|
||||
'stream_mode': ['values', 'messages-tuple', 'custom'],
|
||||
# DeerFlow 2.0 从 config.configurable 读取运行时覆盖
|
||||
# 同时保留 context 字段做向后兼容
|
||||
'context': dict(runtime_configurable),
|
||||
'config': {
|
||||
'recursion_limit': self.recursion_limit,
|
||||
'configurable': runtime_configurable,
|
||||
},
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session/Thread 管理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _ensure_thread_id(self, query: pipeline_query.Query) -> str:
|
||||
"""从 query.session 取/创建 deerflow thread_id
|
||||
|
||||
LangBot 使用 `query.session.using_conversation.uuid` 持久化 conversation id,
|
||||
我们复用这个字段存储 deerflow thread_id(与 Dify Runner 同样做法)。
|
||||
"""
|
||||
thread_id = query.session.using_conversation.uuid or ''
|
||||
if thread_id:
|
||||
return thread_id
|
||||
|
||||
thread = await self.deerflow_client.create_thread(timeout=min(30, self.timeout))
|
||||
thread_id = thread.get('thread_id', '')
|
||||
if not thread_id:
|
||||
raise errors.DeerFlowAPIError(message=f'DeerFlow create thread 返回数据缺少 thread_id: {thread}')
|
||||
|
||||
query.session.using_conversation.uuid = thread_id
|
||||
return thread_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 流式事件处理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_values_event(
|
||||
self,
|
||||
data: typing.Any,
|
||||
state: _StreamState,
|
||||
) -> str | None:
|
||||
"""处理 values 事件,返回新的完整文本(增量基础上的全量)"""
|
||||
values_messages = stream_utils.extract_messages_from_values_data(data)
|
||||
if not values_messages:
|
||||
return None
|
||||
|
||||
new_messages: list[dict[str, typing.Any]] = []
|
||||
if not state.baseline_initialized:
|
||||
state.baseline_initialized = True
|
||||
for idx, msg in enumerate(values_messages):
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
new_messages.append(msg)
|
||||
msg_id = stream_utils.get_message_id(msg)
|
||||
if msg_id:
|
||||
self._remember_seen_message_id(state, msg_id)
|
||||
continue
|
||||
state.no_id_message_fingerprints[idx] = self._fingerprint_message(msg)
|
||||
else:
|
||||
new_messages = self._extract_new_messages_from_values(values_messages, state)
|
||||
|
||||
latest_text = ''
|
||||
if new_messages:
|
||||
state.run_values_messages.extend(new_messages)
|
||||
if len(state.run_values_messages) > _MAX_VALUES_HISTORY:
|
||||
state.run_values_messages = state.run_values_messages[-_MAX_VALUES_HISTORY:]
|
||||
latest_text = stream_utils.extract_latest_ai_text(state.run_values_messages)
|
||||
if latest_text:
|
||||
state.has_values_text = True
|
||||
latest_clarification = stream_utils.extract_latest_clarification_text(
|
||||
state.run_values_messages,
|
||||
)
|
||||
if latest_clarification:
|
||||
state.clarification_text = latest_clarification
|
||||
|
||||
return latest_text or None
|
||||
|
||||
def _handle_message_event(
|
||||
self,
|
||||
data: typing.Any,
|
||||
state: _StreamState,
|
||||
) -> str | None:
|
||||
"""处理 messages-tuple 事件,返回增量文本
|
||||
|
||||
当 values 事件已经提供完整文本时,跳过 messages-tuple 的增量
|
||||
"""
|
||||
delta = stream_utils.extract_ai_delta_from_event_data(data)
|
||||
if delta and not state.has_values_text:
|
||||
state.latest_text += delta
|
||||
return delta
|
||||
|
||||
maybe_clar = stream_utils.extract_clarification_from_event_data(data)
|
||||
if maybe_clar:
|
||||
state.clarification_text = maybe_clar
|
||||
return None
|
||||
|
||||
def _build_final_text(self, state: _StreamState) -> str:
|
||||
"""构建最终输出文本"""
|
||||
if state.clarification_text:
|
||||
return state.clarification_text
|
||||
|
||||
# 优先使用最后一条 AI message 的文本
|
||||
latest_ai = stream_utils.extract_latest_ai_message(state.run_values_messages)
|
||||
if latest_ai:
|
||||
text = stream_utils.extract_text(latest_ai.get('content'))
|
||||
if text:
|
||||
if state.timed_out:
|
||||
text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
|
||||
return text
|
||||
|
||||
if state.latest_text:
|
||||
text = state.latest_text
|
||||
if state.timed_out:
|
||||
text += f'\n\nDeerFlow stream 在 {self.timeout}s 后超时,返回部分结果。'
|
||||
return text
|
||||
|
||||
# 提取任务失败信息作兜底
|
||||
failure_text = stream_utils.build_task_failure_summary(state.task_failures)
|
||||
if failure_text:
|
||||
return failure_text
|
||||
|
||||
return 'DeerFlow 返回空响应'
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 主流程
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _stream_messages_chunk(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""流式输出生成器"""
|
||||
plain_text, image_urls = self._preprocess_user_message(query)
|
||||
|
||||
system_prompt = ''
|
||||
# LangBot 的 pipeline 通常通过 prompt-preprocess 已注入 system prompt
|
||||
# 这里保持空,让 prompt-preprocess 的内容作为 user message 一并送给 deerflow
|
||||
|
||||
thread_id = await self._ensure_thread_id(query)
|
||||
payload = self._build_payload(
|
||||
thread_id=thread_id,
|
||||
prompt=plain_text or 'continue',
|
||||
image_urls=image_urls,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
state = _StreamState()
|
||||
prev_text = ''
|
||||
message_idx = 0
|
||||
|
||||
try:
|
||||
async for event in self.deerflow_client.stream_run(
|
||||
thread_id=thread_id,
|
||||
payload=payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = event.get('event')
|
||||
data = event.get('data')
|
||||
|
||||
if event_type == 'values':
|
||||
new_full = self._handle_values_event(data, state)
|
||||
if new_full and new_full != prev_text:
|
||||
delta = new_full[len(prev_text) :] if new_full.startswith(prev_text) else new_full
|
||||
prev_text = new_full
|
||||
if delta:
|
||||
message_idx += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=new_full,
|
||||
is_final=False,
|
||||
)
|
||||
continue
|
||||
|
||||
if event_type in {'messages-tuple', 'messages', 'message'}:
|
||||
delta = self._handle_message_event(data, state)
|
||||
if delta:
|
||||
prev_text = state.latest_text
|
||||
message_idx += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=prev_text,
|
||||
is_final=False,
|
||||
)
|
||||
continue
|
||||
|
||||
if event_type == 'custom':
|
||||
state.task_failures.extend(
|
||||
stream_utils.extract_task_failures_from_custom_event(data),
|
||||
)
|
||||
continue
|
||||
|
||||
if event_type == 'error':
|
||||
raise errors.DeerFlowAPIError(message=f'DeerFlow stream error event: {data}')
|
||||
|
||||
if event_type == 'end':
|
||||
break
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
self.ap.logger.warning(f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}')
|
||||
state.timed_out = True
|
||||
|
||||
# 最终消息
|
||||
final_text = self._build_final_text(state)
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=final_text,
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
async def _messages(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""非流式聚合输出"""
|
||||
plain_text, image_urls = self._preprocess_user_message(query)
|
||||
|
||||
thread_id = await self._ensure_thread_id(query)
|
||||
payload = self._build_payload(
|
||||
thread_id=thread_id,
|
||||
prompt=plain_text or 'continue',
|
||||
image_urls=image_urls,
|
||||
)
|
||||
|
||||
state = _StreamState()
|
||||
|
||||
try:
|
||||
async for event in self.deerflow_client.stream_run(
|
||||
thread_id=thread_id,
|
||||
payload=payload,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = event.get('event')
|
||||
data = event.get('data')
|
||||
|
||||
if event_type == 'values':
|
||||
self._handle_values_event(data, state)
|
||||
continue
|
||||
|
||||
if event_type in {'messages-tuple', 'messages', 'message'}:
|
||||
self._handle_message_event(data, state)
|
||||
continue
|
||||
|
||||
if event_type == 'custom':
|
||||
state.task_failures.extend(
|
||||
stream_utils.extract_task_failures_from_custom_event(data),
|
||||
)
|
||||
continue
|
||||
|
||||
if event_type == 'error':
|
||||
raise errors.DeerFlowAPIError(message=f'DeerFlow stream error event: {data}')
|
||||
|
||||
if event_type == 'end':
|
||||
break
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
self.ap.logger.warning(f'DeerFlow stream timed out after {self.timeout}s for thread_id={thread_id}')
|
||||
state.timed_out = True
|
||||
|
||||
final_text = self._build_final_text(state)
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=final_text,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""主入口:根据 adapter 是否支持流式输出,选择流式或非流式"""
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
msg_idx = 0
|
||||
async for msg in self._stream_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
else:
|
||||
async for msg in self._messages(query):
|
||||
yield msg
|
||||
@@ -1,351 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
|
||||
|
||||
from langbot.pkg.provider import runner
|
||||
from langbot.pkg.core import app
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from langbot.libs.weknora_api import client, errors
|
||||
|
||||
|
||||
@runner.runner_class('weknora-api')
|
||||
class WeKnoraAPIRunner(runner.RequestRunner):
|
||||
"""WeKnora API 对话请求器"""
|
||||
|
||||
weknora_client: client.AsyncWeKnoraClient
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
super().__init__(ap, pipeline_config)
|
||||
|
||||
valid_app_types = ['chat', 'agent']
|
||||
if self.pipeline_config['ai']['weknora-api']['app-type'] not in valid_app_types:
|
||||
raise errors.WeKnoraAPIError(
|
||||
f'不支持的 WeKnora 应用类型: {self.pipeline_config["ai"]["weknora-api"]["app-type"]}'
|
||||
)
|
||||
|
||||
api_key = self.pipeline_config['ai']['weknora-api'].get('api-key', '').strip()
|
||||
if not api_key:
|
||||
raise errors.WeKnoraAPIError(
|
||||
'WeKnora API Key 未配置,请在流水线的 WeKnora API 配置中填入 API Key '
|
||||
'(从 WeKnora 前端 设置 → API Keys 生成)'
|
||||
)
|
||||
|
||||
base_url = self.pipeline_config['ai']['weknora-api'].get('base-url', '').strip()
|
||||
if not base_url:
|
||||
raise errors.WeKnoraAPIError('WeKnora Base URL 未配置,请填入服务器地址,例如 http://localhost:8080/api/v1')
|
||||
|
||||
self.weknora_client = client.AsyncWeKnoraClient(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
async def _extract_plain_text(self, query: pipeline_query.Query) -> str:
|
||||
"""从用户消息中提取纯文本内容"""
|
||||
plain_text = ''
|
||||
if isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
elif isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
|
||||
if not plain_text:
|
||||
plain_text = self.pipeline_config['ai']['weknora-api'].get('base-prompt', '')
|
||||
|
||||
return plain_text
|
||||
|
||||
async def _ensure_session(self, query: pipeline_query.Query) -> str:
|
||||
"""确保会话存在,如果不存在则创建"""
|
||||
session_id = query.session.using_conversation.uuid or ''
|
||||
|
||||
if not session_id:
|
||||
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
session_id = await self.weknora_client.create_session(title=f'IM Chat - {user_tag}')
|
||||
query.session.using_conversation.uuid = session_id
|
||||
|
||||
return session_id
|
||||
|
||||
async def _agent_chat_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用 Agent 智能对话(非流式聚合输出)"""
|
||||
session_id = await self._ensure_session(query)
|
||||
plain_text = await self._extract_plain_text(query)
|
||||
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
|
||||
config = self.pipeline_config['ai']['weknora-api']
|
||||
agent_id = config.get('agent-id', 'builtin-smart-reasoning')
|
||||
knowledge_base_ids = config.get('knowledge-base-ids', [])
|
||||
web_search_enabled = config.get('web-search-enabled', False)
|
||||
timeout = config.get('timeout', 120)
|
||||
|
||||
full_answer = ''
|
||||
chunk = None
|
||||
|
||||
async for chunk in self.weknora_client.agent_chat(
|
||||
session_id=session_id,
|
||||
query=plain_text,
|
||||
user=user_tag,
|
||||
agent_id=agent_id,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
web_search_enabled=web_search_enabled,
|
||||
timeout=timeout,
|
||||
):
|
||||
self.ap.logger.debug('weknora-agent-chunk: ' + str(chunk))
|
||||
|
||||
response_type = chunk.get('response_type', '')
|
||||
content = chunk.get('content', '')
|
||||
|
||||
if response_type == 'tool_call':
|
||||
# 工具调用
|
||||
tool_data = chunk.get('data', {})
|
||||
tool_name = tool_data.get('tool_name', '')
|
||||
if tool_name:
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id=chunk.get('id', ''),
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name=tool_name,
|
||||
arguments=json.dumps(tool_data.get('arguments', {})),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
elif response_type == 'answer':
|
||||
if content:
|
||||
full_answer += content
|
||||
|
||||
elif response_type == 'error':
|
||||
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
|
||||
|
||||
if chunk is None:
|
||||
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
if full_answer:
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=full_answer,
|
||||
)
|
||||
|
||||
async def _chat_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用知识库 RAG 问答(非流式聚合输出)"""
|
||||
session_id = await self._ensure_session(query)
|
||||
plain_text = await self._extract_plain_text(query)
|
||||
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
|
||||
config = self.pipeline_config['ai']['weknora-api']
|
||||
agent_id = config.get('agent-id', 'builtin-quick-answer')
|
||||
knowledge_base_ids = config.get('knowledge-base-ids', [])
|
||||
timeout = config.get('timeout', 120)
|
||||
|
||||
full_answer = ''
|
||||
chunk = None
|
||||
|
||||
async for chunk in self.weknora_client.knowledge_chat(
|
||||
session_id=session_id,
|
||||
query=plain_text,
|
||||
user=user_tag,
|
||||
agent_id=agent_id,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
timeout=timeout,
|
||||
):
|
||||
self.ap.logger.debug('weknora-chat-chunk: ' + str(chunk))
|
||||
|
||||
response_type = chunk.get('response_type', '')
|
||||
content = chunk.get('content', '')
|
||||
|
||||
if response_type == 'answer':
|
||||
if content:
|
||||
full_answer += content
|
||||
|
||||
elif response_type == 'error':
|
||||
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
|
||||
|
||||
if chunk is None:
|
||||
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
if full_answer:
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=full_answer,
|
||||
)
|
||||
|
||||
async def _agent_chat_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用 Agent 智能对话(流式输出)"""
|
||||
session_id = await self._ensure_session(query)
|
||||
plain_text = await self._extract_plain_text(query)
|
||||
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
|
||||
config = self.pipeline_config['ai']['weknora-api']
|
||||
agent_id = config.get('agent-id', 'builtin-smart-reasoning')
|
||||
knowledge_base_ids = config.get('knowledge-base-ids', [])
|
||||
web_search_enabled = config.get('web-search-enabled', False)
|
||||
timeout = config.get('timeout', 120)
|
||||
|
||||
pending_answer = ''
|
||||
message_idx = 0
|
||||
is_final = False
|
||||
chunk = None
|
||||
|
||||
async for chunk in self.weknora_client.agent_chat(
|
||||
session_id=session_id,
|
||||
query=plain_text,
|
||||
user=user_tag,
|
||||
agent_id=agent_id,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
web_search_enabled=web_search_enabled,
|
||||
timeout=timeout,
|
||||
):
|
||||
self.ap.logger.debug('weknora-agent-chunk: ' + str(chunk))
|
||||
|
||||
response_type = chunk.get('response_type', '')
|
||||
content = chunk.get('content', '')
|
||||
done = chunk.get('done', False)
|
||||
|
||||
if response_type == 'tool_call':
|
||||
tool_data = chunk.get('data', {})
|
||||
tool_name = tool_data.get('tool_name', '')
|
||||
if tool_name:
|
||||
message_idx += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id=chunk.get('id', ''),
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name=tool_name,
|
||||
arguments=json.dumps(tool_data.get('arguments', {})),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
elif response_type == 'answer':
|
||||
message_idx += 1
|
||||
if content:
|
||||
pending_answer += content
|
||||
|
||||
if done:
|
||||
is_final = True
|
||||
|
||||
# 每 8 个 chunk 输出一次,或最终输出
|
||||
if message_idx % 8 == 0 or is_final:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_answer,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
elif response_type == 'error':
|
||||
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
|
||||
|
||||
if chunk is None:
|
||||
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
# 确保最终消息已发出
|
||||
if not is_final and pending_answer:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_answer,
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
async def _chat_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用知识库 RAG 问答(流式输出)"""
|
||||
session_id = await self._ensure_session(query)
|
||||
plain_text = await self._extract_plain_text(query)
|
||||
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
|
||||
config = self.pipeline_config['ai']['weknora-api']
|
||||
agent_id = config.get('agent-id', 'builtin-quick-answer')
|
||||
knowledge_base_ids = config.get('knowledge-base-ids', [])
|
||||
timeout = config.get('timeout', 120)
|
||||
|
||||
pending_answer = ''
|
||||
message_idx = 0
|
||||
is_final = False
|
||||
chunk = None
|
||||
|
||||
async for chunk in self.weknora_client.knowledge_chat(
|
||||
session_id=session_id,
|
||||
query=plain_text,
|
||||
user=user_tag,
|
||||
agent_id=agent_id,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
timeout=timeout,
|
||||
):
|
||||
self.ap.logger.debug('weknora-chat-chunk: ' + str(chunk))
|
||||
|
||||
response_type = chunk.get('response_type', '')
|
||||
content = chunk.get('content', '')
|
||||
done = chunk.get('done', False)
|
||||
|
||||
if response_type == 'answer':
|
||||
message_idx += 1
|
||||
if content:
|
||||
pending_answer += content
|
||||
|
||||
if done:
|
||||
is_final = True
|
||||
|
||||
if message_idx % 8 == 0 or is_final:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_answer,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
elif response_type == 'error':
|
||||
raise errors.WeKnoraAPIError(f'WeKnora 服务错误: {content}')
|
||||
|
||||
if chunk is None:
|
||||
raise errors.WeKnoraAPIError('WeKnora API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
if not is_final and pending_answer:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_answer,
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行请求"""
|
||||
app_type = self.pipeline_config['ai']['weknora-api']['app-type']
|
||||
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
msg_idx = 0
|
||||
if app_type == 'agent':
|
||||
async for msg in self._agent_chat_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
elif app_type == 'chat':
|
||||
async for msg in self._chat_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
else:
|
||||
raise errors.WeKnoraAPIError(f'不支持的 WeKnora 应用类型: {app_type}')
|
||||
else:
|
||||
if app_type == 'agent':
|
||||
async for msg in self._agent_chat_messages(query):
|
||||
yield msg
|
||||
elif app_type == 'chat':
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
else:
|
||||
raise errors.WeKnoraAPIError(f'不支持的 WeKnora 应用类型: {app_type}')
|
||||
@@ -1038,6 +1038,8 @@ else:
|
||||
run_id=run_session.get('run_id') if isinstance(run_session, dict) else None,
|
||||
runner_id=run_session.get('runner_id') if isinstance(run_session, dict) else None,
|
||||
bot_id=getattr(query, 'bot_uuid', None),
|
||||
workspace_id=authorization.get('workspace_id'),
|
||||
thread_id=authorization.get('thread_id'),
|
||||
metadata=metadata,
|
||||
)
|
||||
artifact_ref = {
|
||||
|
||||
@@ -43,6 +43,9 @@ def make_session(
|
||||
plugin_identity: str = 'test/test-runner',
|
||||
resources: dict | None = None,
|
||||
conversation_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
available_apis: dict[str, bool] | None = None,
|
||||
state_policy: dict[str, typing.Any] | None = None,
|
||||
state_context: dict[str, typing.Any] | None = None,
|
||||
@@ -114,6 +117,9 @@ def make_session(
|
||||
'resources': res,
|
||||
'available_apis': apis,
|
||||
'conversation_id': conversation_id,
|
||||
'bot_id': bot_id,
|
||||
'workspace_id': workspace_id,
|
||||
'thread_id': thread_id,
|
||||
'state_policy': policy,
|
||||
'state_context': context,
|
||||
'authorized_ids': authorized_ids,
|
||||
|
||||
@@ -208,10 +208,20 @@ class TestArtifactAuthorization:
|
||||
class TestArtifactAccessValidation:
|
||||
"""Test _validate_artifact_access authorization rules."""
|
||||
|
||||
def _make_session(self, conversation_id: str | None):
|
||||
def _make_session(
|
||||
self,
|
||||
conversation_id: str | None,
|
||||
*,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
):
|
||||
return make_session(
|
||||
run_id="run_001",
|
||||
conversation_id=conversation_id,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
thread_id=thread_id,
|
||||
available_apis={"artifact_metadata": True, "artifact_read": True},
|
||||
)
|
||||
|
||||
@@ -259,6 +269,64 @@ class TestArtifactAccessValidation:
|
||||
assert is_allowed is True
|
||||
assert error is None
|
||||
|
||||
def test_same_conversation_and_scope_allowed(self):
|
||||
"""Artifacts in the same run scope are allowed across runs."""
|
||||
session = self._make_session(
|
||||
"conv_001",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
)
|
||||
metadata = {
|
||||
"artifact_id": "art_001",
|
||||
"conversation_id": "conv_001",
|
||||
"run_id": "run_other",
|
||||
"bot_id": "bot_001",
|
||||
"workspace_id": "workspace_001",
|
||||
"thread_id": "thread_001",
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is True
|
||||
assert error is None
|
||||
|
||||
def test_same_conversation_different_scope_denied(self):
|
||||
"""Artifacts in another bot/thread scope are denied even in the same conversation."""
|
||||
session = self._make_session(
|
||||
"conv_001",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
)
|
||||
metadata = {
|
||||
"artifact_id": "art_001",
|
||||
"conversation_id": "conv_001",
|
||||
"run_id": "run_other",
|
||||
"bot_id": "bot_002",
|
||||
"workspace_id": "workspace_001",
|
||||
"thread_id": "thread_001",
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is False
|
||||
assert "denied" in error.lower()
|
||||
|
||||
def test_same_conversation_missing_scope_denied_for_scoped_session(self):
|
||||
"""Scoped runs should not read legacy-scope artifacts from other runs."""
|
||||
session = self._make_session("conv_001", bot_id="bot_001")
|
||||
metadata = {
|
||||
"artifact_id": "art_001",
|
||||
"conversation_id": "conv_001",
|
||||
"run_id": "run_other",
|
||||
"bot_id": None,
|
||||
"workspace_id": None,
|
||||
"thread_id": None,
|
||||
}
|
||||
|
||||
is_allowed, error = self._call_validate(session, metadata)
|
||||
assert is_allowed is False
|
||||
assert "denied" in error.lower()
|
||||
|
||||
def test_different_conversation_and_run_denied(self):
|
||||
"""Artifacts in different conversation and different run are denied."""
|
||||
session = self._make_session("conv_001")
|
||||
@@ -470,6 +538,9 @@ class TestArtifactStoreRealSQLite:
|
||||
content=content,
|
||||
conversation_id="conv_001",
|
||||
run_id="run_001",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
)
|
||||
|
||||
assert artifact_id == "art_real_001"
|
||||
@@ -489,6 +560,14 @@ class TestArtifactStoreRealSQLite:
|
||||
assert "storage_type" not in metadata
|
||||
assert "bot_id" not in metadata
|
||||
assert "workspace_id" not in metadata
|
||||
assert "thread_id" not in metadata
|
||||
assert "_langbot_thread_id" not in metadata.get("metadata", {})
|
||||
|
||||
auth_metadata = await store.get_authorization_metadata(artifact_id)
|
||||
assert auth_metadata is not None
|
||||
assert auth_metadata["bot_id"] == "bot_001"
|
||||
assert auth_metadata["workspace_id"] == "workspace_001"
|
||||
assert auth_metadata["thread_id"] == "thread_001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_artifact_round_trip(self, db_engine):
|
||||
|
||||
@@ -628,6 +628,52 @@ class TestTranscriptStoreRealSQLite:
|
||||
assert messages[0].content[0].text == "User structured text"
|
||||
assert messages[1].content == "Assistant text"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_legacy_provider_messages_filters_scope(self, db_engine):
|
||||
"""Legacy Pipeline history projection must stay inside the current run scope."""
|
||||
store = TranscriptStore(db_engine)
|
||||
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_scope_001",
|
||||
event_id="evt_scope_001",
|
||||
conversation_id="conv_scope",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
role="user",
|
||||
content="Current scope text",
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_scope_002",
|
||||
event_id="evt_scope_002",
|
||||
conversation_id="conv_scope",
|
||||
bot_id="bot_002",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
role="assistant",
|
||||
content="Other bot text",
|
||||
)
|
||||
await store.append_transcript(
|
||||
transcript_id="trans_scope_003",
|
||||
event_id="evt_scope_003",
|
||||
conversation_id="conv_scope",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_002",
|
||||
role="assistant",
|
||||
content="Other thread text",
|
||||
)
|
||||
|
||||
messages = await store.get_legacy_provider_messages(
|
||||
"conv_scope",
|
||||
bot_id="bot_001",
|
||||
workspace_id="workspace_001",
|
||||
thread_id="thread_001",
|
||||
strict_thread=True,
|
||||
)
|
||||
|
||||
assert [message.content for message in messages] == ["Current scope text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_transcript_real_db(self, db_engine):
|
||||
"""Test search_transcript with real DB."""
|
||||
|
||||
@@ -193,6 +193,41 @@ async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_resources_accepts_dynamic_form_type_aliases(app):
|
||||
"""Frontend DynamicForm aliases should resolve to runtime resource grants."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
|
||||
async def get_kb(kb_uuid):
|
||||
return SimpleNamespace(
|
||||
uuid=kb_uuid,
|
||||
get_name=lambda: f'name-{kb_uuid}',
|
||||
knowledge_base_entity=SimpleNamespace(kb_type='default'),
|
||||
)
|
||||
|
||||
app.rag_mgr.get_knowledge_base_by_uuid = AsyncMock(side_effect=get_kb)
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'knowledge_retrieval': True},
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'select-llm-model'},
|
||||
{'name': 'knowledge-bases', 'type': 'select-knowledge-bases'},
|
||||
],
|
||||
)
|
||||
query = make_query({
|
||||
'model': 'llm_alias',
|
||||
'knowledge-bases': ['kb_alias'],
|
||||
})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'llm_alias', 'model_type': 'llm', 'provider': 'test-provider', 'operations': ['invoke', 'stream']},
|
||||
]
|
||||
assert resources['knowledge_bases'] == [
|
||||
{'kb_id': 'kb_alias', 'kb_name': 'name-kb_alias', 'kb_type': 'default', 'operations': ['list', 'retrieve']},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_manifest_permission_narrows_binding(app):
|
||||
"""Manifest model permissions narrower than binding should remove LLM grants."""
|
||||
|
||||
@@ -288,6 +288,45 @@ class TestSessionRegistryBasic:
|
||||
assert len(items) == MAX_STEERING_QUEUE_ITEMS
|
||||
assert all(item['event']['event_id'] != 'overflow' for item in items)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_steering_target_requires_same_scope(self):
|
||||
"""Steering claims must not cross bot/workspace/thread boundaries."""
|
||||
registry = AgentRunSessionRegistry()
|
||||
await registry.register(
|
||||
run_id='run_steering_scoped',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=make_resources(),
|
||||
conversation_id='conv_1',
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
thread_id='thread_1',
|
||||
available_apis={'steering_pull': True},
|
||||
)
|
||||
|
||||
assert await registry.find_steering_target(
|
||||
conversation_id='conv_1',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
thread_id='thread_1',
|
||||
) == 'run_steering_scoped'
|
||||
assert await registry.find_steering_target(
|
||||
conversation_id='conv_1',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
bot_id='bot_2',
|
||||
workspace_id='workspace_1',
|
||||
thread_id='thread_1',
|
||||
) is None
|
||||
assert await registry.find_steering_target(
|
||||
conversation_id='conv_1',
|
||||
runner_id='plugin:test/my-runner/default',
|
||||
bot_id='bot_1',
|
||||
workspace_id='workspace_1',
|
||||
thread_id='thread_2',
|
||||
) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_returns_pending_steering_queue(self):
|
||||
"""Unregister returns the removed session so callers can audit pending steering."""
|
||||
|
||||
@@ -630,7 +630,9 @@ class TestMCPServiceTestMCPServer:
|
||||
ap.tool_mgr.mcp_tool_loader = SimpleNamespace()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.server_name = 'transient-test-server'
|
||||
mock_session.start = AsyncMock()
|
||||
mock_session.shutdown = AsyncMock()
|
||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server = AsyncMock(return_value=mock_session)
|
||||
|
||||
ap.task_mgr = SimpleNamespace()
|
||||
@@ -645,4 +647,9 @@ class TestMCPServiceTestMCPServer:
|
||||
|
||||
# Verify - load_mcp_server called
|
||||
ap.tool_mgr.mcp_tool_loader.load_mcp_server.assert_called_once()
|
||||
assert task_id == 456
|
||||
assert task_id == 456
|
||||
|
||||
coroutine = ap.task_mgr.create_user_task.call_args.args[0]
|
||||
await coroutine
|
||||
mock_session.start.assert_awaited_once()
|
||||
mock_session.shutdown.assert_awaited_once()
|
||||
|
||||
@@ -181,6 +181,23 @@ def make_app(
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_box_session_id_reads_current_runner_config():
|
||||
query = make_query(101)
|
||||
query.pipeline_config = {
|
||||
'ai': {
|
||||
'runner': {'id': 'plugin:langbot/local-agent/default'},
|
||||
'runner_config': {
|
||||
'plugin:langbot/local-agent/default': {
|
||||
'box-session-id-template': 'bot-{launcher_id}-{sender_id}',
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
service = BoxService(make_app(Mock()), client=Mock(spec=BoxRuntimeClient))
|
||||
|
||||
assert service.resolve_box_session_id(query) == 'bot-test_user-test_user'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_box_service_without_explicit_client_initializes_internal_connector(monkeypatch: pytest.MonkeyPatch):
|
||||
connector = Mock()
|
||||
|
||||
@@ -273,6 +273,13 @@ async def test_preproc_uses_transcript_history_view_when_available():
|
||||
|
||||
assert result.result_type == entities_module.ResultType.CONTINUE
|
||||
assert query.messages == transcript_messages
|
||||
stage._load_agent_runner_history_messages.assert_awaited_once_with(
|
||||
'plugin:langbot/local-agent/default',
|
||||
'conv-1',
|
||||
bot_id='bot-1',
|
||||
workspace_id=None,
|
||||
thread_id=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user