mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-10 15:56:03 +00:00
feat(agent-runner): enforce typed host permissions
This commit is contained in:
@@ -26,49 +26,26 @@ def iter_schema_items(
|
||||
yield item
|
||||
|
||||
|
||||
def has_permission(
|
||||
descriptor: AgentRunnerDescriptor | None,
|
||||
name: str,
|
||||
actions: set[str],
|
||||
) -> bool:
|
||||
"""Return whether a runner descriptor requests one of the given actions."""
|
||||
if descriptor is None:
|
||||
return False
|
||||
configured_actions = descriptor.permissions.get(name, [])
|
||||
return any(action in configured_actions for action in actions)
|
||||
|
||||
|
||||
def uses_host_models(descriptor: AgentRunnerDescriptor | None) -> bool:
|
||||
"""Return whether LangBot should resolve model resources for this runner."""
|
||||
return (
|
||||
has_permission(descriptor, 'models', {'invoke', 'stream', 'list'})
|
||||
and any(True for _ in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES))
|
||||
)
|
||||
return any(True for _ in iter_schema_items(descriptor, LLM_MODEL_SELECTOR_TYPES))
|
||||
|
||||
|
||||
def uses_host_tools(descriptor: AgentRunnerDescriptor | None) -> bool:
|
||||
"""Return whether LangBot should expose tool resources to this runner."""
|
||||
return (
|
||||
descriptor is not None
|
||||
and descriptor.supports_tool_calling()
|
||||
and has_permission(descriptor, 'tools', {'list', 'detail', 'call'})
|
||||
)
|
||||
return descriptor is not None and descriptor.supports_tool_calling()
|
||||
|
||||
|
||||
def uses_host_knowledge_bases(descriptor: AgentRunnerDescriptor | None) -> bool:
|
||||
"""Return whether LangBot should expose knowledge-base resources to this runner."""
|
||||
return (
|
||||
descriptor is not None
|
||||
and descriptor.supports_knowledge_retrieval()
|
||||
and has_permission(descriptor, 'knowledge_bases', {'list', 'retrieve'})
|
||||
)
|
||||
return descriptor is not None and descriptor.supports_knowledge_retrieval()
|
||||
|
||||
|
||||
def supports_skill_authoring(descriptor: AgentRunnerDescriptor | None) -> bool:
|
||||
"""Return whether the runner wants Host skill-authoring tools."""
|
||||
if descriptor is None:
|
||||
return False
|
||||
return bool(descriptor.capabilities.get('skill_authoring', False))
|
||||
return descriptor.capabilities.skill_authoring
|
||||
|
||||
|
||||
def extract_prompt_config(
|
||||
|
||||
@@ -44,7 +44,6 @@ class AgentInput(typing.TypedDict):
|
||||
|
||||
text: str | None
|
||||
contents: list[dict[str, typing.Any]]
|
||||
message_chain: dict[str, typing.Any] | None
|
||||
attachments: list[dict[str, typing.Any]]
|
||||
|
||||
|
||||
@@ -254,7 +253,6 @@ class AgentRunContextBuilder:
|
||||
input: AgentInput = {
|
||||
'text': event.input.text,
|
||||
'contents': [c.model_dump(mode='json') if hasattr(c, 'model_dump') else c for c in event.input.contents],
|
||||
'message_chain': event.input.message_chain,
|
||||
'attachments': [
|
||||
a.model_dump(mode='json') if hasattr(a, 'model_dump') else a for a in event.input.attachments
|
||||
],
|
||||
@@ -361,28 +359,33 @@ class AgentRunContextBuilder:
|
||||
ContextAccess dict
|
||||
"""
|
||||
conversation_id = event.conversation_id
|
||||
permissions = descriptor.permissions
|
||||
history_perms = set(permissions.history)
|
||||
event_perms = set(permissions.events)
|
||||
artifact_perms = set(permissions.artifacts)
|
||||
storage_perms = set(permissions.storage)
|
||||
|
||||
# Check if history APIs are available for this runner
|
||||
# Based on runner permissions
|
||||
permissions = descriptor.permissions or {}
|
||||
history_permissions = permissions.get('history', [])
|
||||
event_permissions = permissions.get('events', [])
|
||||
artifact_permissions = permissions.get('artifacts', [])
|
||||
|
||||
history_page_enabled = 'page' in history_permissions and conversation_id is not None
|
||||
history_search_enabled = 'search' in history_permissions and conversation_id is not None
|
||||
event_get_enabled = 'get' in event_permissions
|
||||
event_page_enabled = 'page' in event_permissions and conversation_id is not None
|
||||
artifact_metadata_enabled = 'metadata' in artifact_permissions
|
||||
artifact_read_enabled = 'read' in artifact_permissions
|
||||
history_page_enabled = 'page' in history_perms and conversation_id is not None
|
||||
history_search_enabled = 'search' in history_perms and conversation_id is not None
|
||||
event_get_enabled = 'get' in event_perms
|
||||
event_page_enabled = 'page' in event_perms and conversation_id is not None
|
||||
artifact_metadata_enabled = 'metadata' in artifact_perms
|
||||
artifact_read_enabled = 'read' in artifact_perms
|
||||
|
||||
# Determine state API availability based on binding state_policy.
|
||||
state_enabled = False
|
||||
storage_enabled = False
|
||||
if binding is not None:
|
||||
state_policy = binding.state_policy
|
||||
if state_policy.enable_state and state_policy.state_scopes:
|
||||
state_enabled = True
|
||||
|
||||
resource_policy = binding.resource_policy
|
||||
storage_enabled = (
|
||||
('plugin' in storage_perms and resource_policy.allow_plugin_storage)
|
||||
or ('workspace' in storage_perms and resource_policy.allow_workspace_storage)
|
||||
)
|
||||
|
||||
# Get latest cursor and has_history_before if conversation exists
|
||||
latest_cursor = None
|
||||
has_history_before = False
|
||||
@@ -411,7 +414,7 @@ class AgentRunContextBuilder:
|
||||
'delivered_count': 0,
|
||||
'source_total_count': None,
|
||||
'messages_complete': False,
|
||||
'reason': 'self_managed_context',
|
||||
'reason': 'current_event_only',
|
||||
},
|
||||
'available_apis': {
|
||||
'history_page': history_page_enabled,
|
||||
@@ -421,6 +424,6 @@ class AgentRunContextBuilder:
|
||||
'artifact_metadata': artifact_metadata_enabled,
|
||||
'artifact_read': artifact_read_enabled,
|
||||
'state': state_enabled,
|
||||
'storage': True,
|
||||
'storage': storage_enabled,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -4,6 +4,11 @@ from __future__ import annotations
|
||||
import typing
|
||||
import pydantic
|
||||
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.manifest import (
|
||||
AgentRunnerCapabilities,
|
||||
AgentRunnerPermissions,
|
||||
)
|
||||
|
||||
|
||||
class AgentRunnerDescriptor(pydantic.BaseModel):
|
||||
"""Descriptor for an agent runner.
|
||||
@@ -36,16 +41,20 @@ class AgentRunnerDescriptor(pydantic.BaseModel):
|
||||
plugin_version: str | None = None
|
||||
"""Optional plugin version"""
|
||||
|
||||
config_schema: list[dict[str, typing.Any]] = []
|
||||
config_schema: list[dict[str, typing.Any]] = pydantic.Field(default_factory=list)
|
||||
"""Configuration schema using DynamicForm format"""
|
||||
|
||||
capabilities: dict[str, bool] = {}
|
||||
capabilities: AgentRunnerCapabilities = pydantic.Field(
|
||||
default_factory=AgentRunnerCapabilities
|
||||
)
|
||||
"""Runner capabilities: streaming, tool_calling, knowledge_retrieval, etc."""
|
||||
|
||||
permissions: dict[str, list[str]] = {}
|
||||
"""Requested permissions: models, tools, knowledge_bases, storage, files, platform_api"""
|
||||
permissions: AgentRunnerPermissions = pydantic.Field(
|
||||
default_factory=AgentRunnerPermissions
|
||||
)
|
||||
"""Requested LangBot resource permissions."""
|
||||
|
||||
raw_manifest: dict[str, typing.Any] = {}
|
||||
raw_manifest: dict[str, typing.Any] = pydantic.Field(default_factory=dict)
|
||||
"""Original manifest for reference"""
|
||||
|
||||
model_config = pydantic.ConfigDict(
|
||||
@@ -58,12 +67,12 @@ class AgentRunnerDescriptor(pydantic.BaseModel):
|
||||
|
||||
def supports_streaming(self) -> bool:
|
||||
"""Check if runner supports streaming output."""
|
||||
return self.capabilities.get('streaming', False)
|
||||
return self.capabilities.streaming
|
||||
|
||||
def supports_tool_calling(self) -> bool:
|
||||
"""Check if runner supports tool calling."""
|
||||
return self.capabilities.get('tool_calling', False)
|
||||
return self.capabilities.tool_calling
|
||||
|
||||
def supports_knowledge_retrieval(self) -> bool:
|
||||
"""Check if runner supports knowledge retrieval."""
|
||||
return self.capabilities.get('knowledge_retrieval', False)
|
||||
return self.capabilities.knowledge_retrieval
|
||||
|
||||
@@ -106,7 +106,7 @@ class AgentRunOrchestrator:
|
||||
query_id=session_query_id,
|
||||
plugin_identity=descriptor.get_plugin_id(),
|
||||
resources=resources,
|
||||
permissions=descriptor.permissions or {},
|
||||
available_apis=context.get('context', {}).get('available_apis'),
|
||||
conversation_id=event.conversation_id,
|
||||
state_policy={
|
||||
'enable_state': binding.state_policy.enable_state,
|
||||
@@ -137,6 +137,12 @@ class AgentRunOrchestrator:
|
||||
try:
|
||||
async for result_dict in self.invoker.invoke(descriptor, context):
|
||||
result_type = result_dict.get('type')
|
||||
if result_type and not self.result_normalizer.validate_payload(
|
||||
result_type,
|
||||
result_dict.get('data', {}),
|
||||
descriptor,
|
||||
):
|
||||
continue
|
||||
|
||||
if result_type == 'artifact.created':
|
||||
artifact_ref = await self.journal.handle_artifact_created(
|
||||
|
||||
@@ -396,18 +396,11 @@ class QueryEntryAdapter:
|
||||
if text_parts:
|
||||
text = ''.join(text_parts)
|
||||
|
||||
message_chain_dict = None
|
||||
message_chain = getattr(query, 'message_chain', None)
|
||||
if message_chain:
|
||||
if hasattr(message_chain, 'model_dump'):
|
||||
message_chain_dict = message_chain.model_dump(mode='json')
|
||||
|
||||
attachments = cls._build_attachments(query, contents)
|
||||
|
||||
return AgentInput(
|
||||
text=text,
|
||||
contents=contents,
|
||||
message_chain=message_chain_dict,
|
||||
attachments=attachments,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,11 @@ from __future__ import annotations
|
||||
import typing
|
||||
import asyncio
|
||||
|
||||
import pydantic
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.manifest import (
|
||||
AgentRunnerManifest,
|
||||
)
|
||||
|
||||
from ...core import app
|
||||
from .descriptor import AgentRunnerDescriptor
|
||||
from .id import parse_runner_id, format_runner_id
|
||||
@@ -79,7 +84,7 @@ class AgentRunnerRegistry:
|
||||
Args:
|
||||
runner_data: Raw runner data from plugin runtime with fields:
|
||||
- plugin_author, plugin_name, runner_name
|
||||
- manifest (full component manifest dict)
|
||||
- manifest (typed AgentRunnerManifest or legacy component manifest)
|
||||
- capabilities, permissions, config (extracted from spec)
|
||||
|
||||
Returns:
|
||||
@@ -93,17 +98,77 @@ class AgentRunnerRegistry:
|
||||
return None
|
||||
|
||||
manifest = runner_data.get('manifest', {})
|
||||
runner_id = format_runner_id(
|
||||
source='plugin',
|
||||
plugin_author=plugin_author,
|
||||
plugin_name=plugin_name,
|
||||
runner_name=runner_name,
|
||||
)
|
||||
|
||||
is_typed_manifest = self._looks_like_typed_manifest(manifest)
|
||||
if is_typed_manifest:
|
||||
typed_manifest = AgentRunnerManifest.model_validate(manifest)
|
||||
else:
|
||||
typed_manifest = self._build_typed_manifest_from_legacy_data(
|
||||
runner_id=runner_id,
|
||||
runner_name=runner_name,
|
||||
runner_data=runner_data,
|
||||
manifest=manifest,
|
||||
)
|
||||
|
||||
if runner_data.get('config'):
|
||||
config_schema = runner_data['config']
|
||||
elif not is_typed_manifest and isinstance(manifest.get('spec'), dict):
|
||||
config_schema = manifest['spec'].get('config', [])
|
||||
else:
|
||||
config_schema = [
|
||||
item.model_dump(mode='json') for item in typed_manifest.config_schema
|
||||
]
|
||||
|
||||
return AgentRunnerDescriptor(
|
||||
id=runner_id,
|
||||
source='plugin',
|
||||
label=typed_manifest.label,
|
||||
description=typed_manifest.description or runner_data.get('runner_description'),
|
||||
plugin_author=plugin_author,
|
||||
plugin_name=plugin_name,
|
||||
runner_name=runner_name,
|
||||
plugin_version=runner_data.get('plugin_version'),
|
||||
config_schema=config_schema,
|
||||
capabilities=typed_manifest.capabilities,
|
||||
permissions=typed_manifest.permissions,
|
||||
raw_manifest=manifest,
|
||||
)
|
||||
|
||||
def _looks_like_typed_manifest(self, manifest: dict[str, typing.Any]) -> bool:
|
||||
"""Return whether manifest is the SDK typed AgentRunnerManifest shape."""
|
||||
return (
|
||||
isinstance(manifest, dict)
|
||||
and 'id' in manifest
|
||||
and 'name' in manifest
|
||||
and 'label' in manifest
|
||||
)
|
||||
|
||||
def _build_typed_manifest_from_legacy_data(
|
||||
self,
|
||||
*,
|
||||
runner_id: str,
|
||||
runner_name: str,
|
||||
runner_data: dict[str, typing.Any],
|
||||
manifest: dict[str, typing.Any],
|
||||
) -> AgentRunnerManifest:
|
||||
"""Validate legacy raw component manifest data as typed runner manifest."""
|
||||
|
||||
# Validate kind
|
||||
kind = manifest.get('kind', '')
|
||||
if kind != 'AgentRunner':
|
||||
return None
|
||||
raise ValueError(f'Invalid AgentRunner kind: {kind or "<missing>"}')
|
||||
|
||||
# Validate metadata
|
||||
metadata = manifest.get('metadata', {})
|
||||
name = metadata.get('name', '')
|
||||
if not name:
|
||||
return None
|
||||
raise ValueError('Missing AgentRunner metadata.name')
|
||||
|
||||
# metadata.label must exist
|
||||
label = metadata.get('label', {})
|
||||
@@ -118,28 +183,20 @@ class AgentRunnerRegistry:
|
||||
capabilities = runner_data.get('capabilities') or spec.get('capabilities', {})
|
||||
permissions = runner_data.get('permissions') or spec.get('permissions', {})
|
||||
|
||||
# Build descriptor
|
||||
runner_id = format_runner_id(
|
||||
source='plugin',
|
||||
plugin_author=plugin_author,
|
||||
plugin_name=plugin_name,
|
||||
runner_name=runner_name,
|
||||
)
|
||||
|
||||
return AgentRunnerDescriptor(
|
||||
id=runner_id,
|
||||
source='plugin',
|
||||
label=label,
|
||||
description=metadata.get('description') or runner_data.get('runner_description'),
|
||||
plugin_author=plugin_author,
|
||||
plugin_name=plugin_name,
|
||||
runner_name=runner_name,
|
||||
plugin_version=runner_data.get('plugin_version'),
|
||||
config_schema=config_schema,
|
||||
capabilities=capabilities,
|
||||
permissions=permissions,
|
||||
raw_manifest=manifest,
|
||||
)
|
||||
try:
|
||||
return AgentRunnerManifest(
|
||||
id=runner_id,
|
||||
name=runner_name,
|
||||
label=label,
|
||||
description=metadata.get('description') or runner_data.get('runner_description'),
|
||||
capabilities=capabilities,
|
||||
permissions=permissions,
|
||||
config_schema=config_schema,
|
||||
)
|
||||
except pydantic.ValidationError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise ValueError(f'Invalid AgentRunner manifest: {exc}') from exc
|
||||
|
||||
async def refresh(self) -> None:
|
||||
"""Refresh runner cache.
|
||||
|
||||
@@ -18,17 +18,14 @@ from .host_models import AgentEventEnvelope, AgentBinding
|
||||
|
||||
|
||||
class AgentResourceBuilder:
|
||||
"""Builder for constructing AgentResources with permission filtering.
|
||||
"""Builder for constructing run-scoped AgentResources with permission filtering.
|
||||
|
||||
Responsibilities:
|
||||
- Apply 3-layer permission filtering:
|
||||
1. Runner manifest declared permissions
|
||||
2. Pipeline extensions_preference (bound plugins/MCP servers)
|
||||
3. Agent/runner config selected resources
|
||||
- Apply manifest permissions intersected with binding resource policy
|
||||
- Build models list from authorized models
|
||||
- Build tools list from bound plugins/MCP servers
|
||||
- Build knowledge_bases list from config
|
||||
- Build storage and files permissions summary
|
||||
- Build storage and files access summary
|
||||
|
||||
Note: This only builds the resource declaration. The actual proxy actions
|
||||
in handler.py must still validate against ctx.resources at runtime.
|
||||
@@ -59,26 +56,21 @@ class AgentResourceBuilder:
|
||||
Args:
|
||||
event: Event envelope
|
||||
binding: Agent binding with resource policy
|
||||
descriptor: Runner descriptor with permissions and capabilities
|
||||
descriptor: Runner descriptor with capabilities, permissions, and config schema
|
||||
|
||||
Returns:
|
||||
AgentResources dict with filtered resource lists
|
||||
"""
|
||||
# Layer 1: Runner manifest permissions
|
||||
manifest_perms = descriptor.permissions
|
||||
|
||||
# Layer 2: Binding resource policy
|
||||
resource_policy = binding.resource_policy
|
||||
|
||||
# Layer 3: Agent/runner config
|
||||
runner_config = binding.runner_config
|
||||
manifest_perms = descriptor.permissions
|
||||
|
||||
# Build each resource category
|
||||
models = await self._build_models_from_binding(
|
||||
manifest_perms, resource_policy, descriptor, runner_config
|
||||
)
|
||||
tools = await self._build_tools_from_binding(
|
||||
manifest_perms, resource_policy, binding
|
||||
manifest_perms, resource_policy, descriptor
|
||||
)
|
||||
knowledge_bases = await self._build_knowledge_bases_from_binding(
|
||||
manifest_perms, resource_policy, descriptor, runner_config
|
||||
@@ -100,7 +92,7 @@ class AgentResourceBuilder:
|
||||
|
||||
async def _build_models_from_binding(
|
||||
self,
|
||||
manifest_perms: dict[str, list[str]],
|
||||
manifest_perms: typing.Any,
|
||||
resource_policy: typing.Any,
|
||||
descriptor: AgentRunnerDescriptor,
|
||||
runner_config: dict[str, typing.Any],
|
||||
@@ -109,10 +101,10 @@ class AgentResourceBuilder:
|
||||
models: list[ModelResource] = []
|
||||
seen_model_ids: set[str] = set()
|
||||
|
||||
model_perms = manifest_perms.get('models', [])
|
||||
allow_llm = 'invoke' in model_perms or 'stream' in model_perms
|
||||
allow_rerank = 'rerank' in model_perms
|
||||
if not allow_llm and not allow_rerank:
|
||||
model_perms = set(manifest_perms.models)
|
||||
include_llm = bool({'invoke', 'stream'} & model_perms)
|
||||
include_rerank = 'rerank' in model_perms
|
||||
if not include_llm and not include_rerank:
|
||||
return models
|
||||
|
||||
# Get additional model UUID grants from resource policy.
|
||||
@@ -124,12 +116,12 @@ class AgentResourceBuilder:
|
||||
seen_model_ids=seen_model_ids,
|
||||
descriptor=descriptor,
|
||||
runner_config=runner_config,
|
||||
include_llm=allow_llm,
|
||||
include_rerank=allow_rerank,
|
||||
include_llm=include_llm,
|
||||
include_rerank=include_rerank,
|
||||
)
|
||||
|
||||
# Add explicitly allowed models
|
||||
if allowed_uuids and allow_llm:
|
||||
if allowed_uuids and include_llm:
|
||||
for model_uuid in allowed_uuids:
|
||||
await self._append_llm_model_resource(models, seen_model_ids, model_uuid)
|
||||
|
||||
@@ -137,16 +129,17 @@ class AgentResourceBuilder:
|
||||
|
||||
async def _build_tools_from_binding(
|
||||
self,
|
||||
manifest_perms: dict[str, list[str]],
|
||||
manifest_perms: typing.Any,
|
||||
resource_policy: typing.Any,
|
||||
binding: AgentBinding,
|
||||
descriptor: AgentRunnerDescriptor,
|
||||
) -> list[ToolResource]:
|
||||
"""Build tools list from binding."""
|
||||
tools: list[ToolResource] = []
|
||||
tool_perms = set(manifest_perms.tools)
|
||||
if not ({'detail', 'call'} & tool_perms):
|
||||
return tools
|
||||
|
||||
# Check manifest permission
|
||||
tool_perms = manifest_perms.get('tools', [])
|
||||
if 'detail' not in tool_perms and 'call' not in tool_perms:
|
||||
if not config_schema.uses_host_tools(descriptor):
|
||||
return tools
|
||||
|
||||
# Get tool names from resource policy
|
||||
@@ -164,17 +157,18 @@ class AgentResourceBuilder:
|
||||
|
||||
async def _build_knowledge_bases_from_binding(
|
||||
self,
|
||||
manifest_perms: dict[str, list[str]],
|
||||
manifest_perms: typing.Any,
|
||||
resource_policy: typing.Any,
|
||||
descriptor: AgentRunnerDescriptor,
|
||||
runner_config: dict[str, typing.Any],
|
||||
) -> list[KnowledgeBaseResource]:
|
||||
"""Build knowledge bases list from binding."""
|
||||
kb_resources: list[KnowledgeBaseResource] = []
|
||||
kb_perms = set(manifest_perms.knowledge_bases)
|
||||
if not ({'list', 'retrieve'} & kb_perms):
|
||||
return kb_resources
|
||||
|
||||
# Check manifest permission
|
||||
kb_perms = manifest_perms.get('knowledge_bases', [])
|
||||
if 'list' not in kb_perms and 'retrieve' not in kb_perms:
|
||||
if not config_schema.uses_host_knowledge_bases(descriptor):
|
||||
return kb_resources
|
||||
|
||||
# Get KB UUID grants from schema-defined config fields.
|
||||
@@ -231,12 +225,12 @@ class AgentResourceBuilder:
|
||||
|
||||
def _build_storage_from_binding(
|
||||
self,
|
||||
manifest_perms: dict[str, list[str]],
|
||||
manifest_perms: typing.Any,
|
||||
binding: AgentBinding,
|
||||
) -> StorageResource:
|
||||
"""Build storage permissions from binding."""
|
||||
storage_perms = manifest_perms.get('storage', [])
|
||||
"""Build storage access summary from manifest and binding policy."""
|
||||
resource_policy = binding.resource_policy
|
||||
storage_perms = set(manifest_perms.storage)
|
||||
|
||||
return {
|
||||
'plugin_storage': 'plugin' in storage_perms and resource_policy.allow_plugin_storage,
|
||||
|
||||
@@ -3,6 +3,16 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.result import (
|
||||
ActionRequestedPayload,
|
||||
ArtifactCreatedPayload,
|
||||
MessageCompletedPayload,
|
||||
MessageDeltaPayload,
|
||||
RunCompletedPayload,
|
||||
RunFailedPayload,
|
||||
StateUpdatedPayload,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
|
||||
from ...core import app
|
||||
@@ -13,6 +23,16 @@ from .errors import RunnerExecutionError, RunnerProtocolError
|
||||
# Maximum size for a single result payload (prevent memory exhaustion)
|
||||
MAX_RESULT_SIZE_BYTES = 1024 * 1024 # 1 MB
|
||||
|
||||
STRICT_RESULT_PAYLOADS: dict[str, type[pydantic.BaseModel]] = {
|
||||
'message.delta': MessageDeltaPayload,
|
||||
'message.completed': MessageCompletedPayload,
|
||||
'state.updated': StateUpdatedPayload,
|
||||
'artifact.created': ArtifactCreatedPayload,
|
||||
'action.requested': ActionRequestedPayload,
|
||||
'run.completed': RunCompletedPayload,
|
||||
'run.failed': RunFailedPayload,
|
||||
}
|
||||
|
||||
|
||||
class AgentResultNormalizer:
|
||||
"""Normalizer for converting AgentRunResult to Pipeline messages.
|
||||
@@ -87,6 +107,9 @@ class AgentResultNormalizer:
|
||||
# Handle each result type
|
||||
data = result_dict.get('data', {})
|
||||
|
||||
if not self.validate_payload(result_type, data, descriptor):
|
||||
return None
|
||||
|
||||
if result_type == 'message.delta':
|
||||
return self._normalize_message_delta(data, descriptor)
|
||||
|
||||
@@ -160,6 +183,31 @@ class AgentResultNormalizer:
|
||||
)
|
||||
return None
|
||||
|
||||
def validate_payload(
|
||||
self,
|
||||
result_type: str,
|
||||
data: typing.Any,
|
||||
descriptor: AgentRunnerDescriptor,
|
||||
) -> bool:
|
||||
"""Validate typed payloads that affect Host state or delivery.
|
||||
|
||||
Tool-call telemetry stays intentionally loose so older runners can keep
|
||||
emitting diagnostic fields. Unknown result types are handled by the
|
||||
caller and are not validated here.
|
||||
"""
|
||||
payload_model = STRICT_RESULT_PAYLOADS.get(result_type)
|
||||
if payload_model is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
payload_model.model_validate(data)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(
|
||||
f'Runner {descriptor.id} returned invalid {result_type} payload; dropping result: {e}'
|
||||
)
|
||||
return False
|
||||
|
||||
def _normalize_message_delta(
|
||||
self,
|
||||
data: dict[str, typing.Any],
|
||||
|
||||
@@ -25,7 +25,7 @@ class RunAuthorizationSnapshot(typing.TypedDict):
|
||||
"""
|
||||
|
||||
resources: AgentResources
|
||||
permissions: dict[str, list[str]]
|
||||
available_apis: dict[str, bool]
|
||||
conversation_id: str | None
|
||||
state_policy: dict[str, typing.Any]
|
||||
state_context: dict[str, typing.Any]
|
||||
@@ -80,7 +80,7 @@ class AgentRunSessionRegistry:
|
||||
plugin_identity: str,
|
||||
resources: AgentResources,
|
||||
conversation_id: str | None = None,
|
||||
permissions: dict[str, list[str]] | None = None,
|
||||
available_apis: dict[str, bool] | None = None,
|
||||
state_policy: dict[str, typing.Any] | None = None,
|
||||
state_context: dict[str, typing.Any] | None = None,
|
||||
) -> None:
|
||||
@@ -93,14 +93,13 @@ class AgentRunSessionRegistry:
|
||||
plugin_identity: Plugin identifier (author/name)
|
||||
resources: Authorized resources for this run
|
||||
conversation_id: Conversation ID for history/event access
|
||||
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
|
||||
available_apis: Run-scoped pull APIs exposed in AgentRunContext
|
||||
state_policy: State policy from binding (enable_state, state_scopes)
|
||||
state_context: Context for state API (scope_keys, binding_identity, etc.)
|
||||
"""
|
||||
now = int(time.time())
|
||||
|
||||
# Normalize permissions to empty dict if None
|
||||
permissions = permissions or {}
|
||||
available_apis = copy.deepcopy(available_apis or {})
|
||||
|
||||
# Normalize state_policy to defaults if None
|
||||
if state_policy is None:
|
||||
@@ -112,7 +111,7 @@ class AgentRunSessionRegistry:
|
||||
resources_snapshot = copy.deepcopy(resources)
|
||||
authorization: RunAuthorizationSnapshot = {
|
||||
'resources': resources_snapshot,
|
||||
'permissions': copy.deepcopy(permissions),
|
||||
'available_apis': available_apis,
|
||||
'conversation_id': conversation_id,
|
||||
'state_policy': copy.deepcopy(state_policy),
|
||||
'state_context': copy.deepcopy(state_context),
|
||||
|
||||
@@ -732,8 +732,8 @@ class BoxService:
|
||||
def get_system_guidance(self) -> str:
|
||||
"""Return LLM system-prompt guidance for the exec tool.
|
||||
|
||||
All execution-specific prompt text is kept here so that callers
|
||||
(e.g. LocalAgentRunner) stay free of box domain knowledge.
|
||||
All execution-specific prompt text is kept here so that callers stay
|
||||
free of box domain knowledge.
|
||||
"""
|
||||
guidance = (
|
||||
'When the exec tool is available, use it for exact calculations, statistics, structured data parsing, '
|
||||
|
||||
@@ -80,7 +80,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
def _runner_accepts_multimodal_input(self, descriptor: AgentRunnerDescriptor | None) -> bool:
|
||||
if descriptor is None:
|
||||
return True
|
||||
return descriptor.capabilities.get('multimodal_input', False)
|
||||
return descriptor.capabilities.multimodal_input
|
||||
|
||||
def _model_supports_vision(self, llm_model: typing.Any | None) -> bool:
|
||||
if not llm_model:
|
||||
|
||||
@@ -184,10 +184,9 @@ async def _validate_agent_run_session(
|
||||
caller_plugin_identity: str | None,
|
||||
ap: app.Application,
|
||||
api_name: str,
|
||||
permission_group: str | None = None,
|
||||
permission_operation: str | None = None,
|
||||
api_capability: str | None = None,
|
||||
) -> Union[tuple[None, handler.ActionResponse], tuple[Any, None]]:
|
||||
"""Validate an AgentRunner pull API run session and optional manifest permission."""
|
||||
"""Validate an AgentRunner pull API run session and run-scoped API access."""
|
||||
session_registry = get_session_registry()
|
||||
session = await session_registry.get(run_id)
|
||||
if not session:
|
||||
@@ -210,10 +209,9 @@ async def _validate_agent_run_session(
|
||||
message=f'Plugin identity mismatch for run_id {run_id}'
|
||||
)
|
||||
|
||||
if permission_group and permission_operation:
|
||||
permissions = _get_run_authorization(session)['permissions']
|
||||
allowed_operations = permissions.get(permission_group, [])
|
||||
if permission_operation not in allowed_operations:
|
||||
if api_capability:
|
||||
available_apis = _get_run_authorization(session).get('available_apis', {})
|
||||
if not available_apis.get(api_capability, False):
|
||||
return None, handler.ActionResponse.error(
|
||||
message=f'{api_name} access not authorized'
|
||||
)
|
||||
@@ -1489,8 +1487,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'History page',
|
||||
permission_group='history',
|
||||
permission_operation='page',
|
||||
api_capability='history_page',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -1560,8 +1557,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'History search',
|
||||
permission_group='history',
|
||||
permission_operation='search',
|
||||
api_capability='history_search',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -1625,8 +1621,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'Event get',
|
||||
permission_group='events',
|
||||
permission_operation='get',
|
||||
api_capability='event_get',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -1678,8 +1673,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'Event page',
|
||||
permission_group='events',
|
||||
permission_operation='page',
|
||||
api_capability='event_page',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -1749,8 +1743,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'Artifact metadata',
|
||||
permission_group='artifacts',
|
||||
permission_operation='metadata',
|
||||
api_capability='artifact_metadata',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -1820,8 +1813,7 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
caller_plugin_identity,
|
||||
self.ap,
|
||||
'Artifact read',
|
||||
permission_group='artifacts',
|
||||
permission_operation='read',
|
||||
api_capability='artifact_read',
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
@@ -2218,8 +2210,6 @@ class RuntimeConnectionHandler(handler.Handler):
|
||||
- runner_name
|
||||
- runner_description
|
||||
- manifest
|
||||
- capabilities
|
||||
- permissions
|
||||
- config
|
||||
"""
|
||||
result = await self.call_action(
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
preregistered_runners: list[typing.Type[RequestRunner]] = []
|
||||
|
||||
|
||||
def runner_class(name: str):
|
||||
"""注册一个请求运行器"""
|
||||
|
||||
def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]:
|
||||
cls.name = name
|
||||
preregistered_runners.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RequestRunner(abc.ABC):
|
||||
"""请求运行器"""
|
||||
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
|
||||
pipeline_config: dict
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message | provider_message.MessageChunk, None]:
|
||||
"""运行请求"""
|
||||
pass
|
||||
@@ -1,295 +0,0 @@
|
||||
"""
|
||||
Legacy Coze API Runner.
|
||||
|
||||
DEPRECATED: This runner has been migrated to the AgentRunner plugin format.
|
||||
Use the official `langbot/coze-agent` plugin instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import base64
|
||||
|
||||
from langbot.pkg.provider import runner
|
||||
from langbot.pkg.core import app
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
from langbot.pkg.utils import image
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from langbot.libs.coze_server_api.client import AsyncCozeAPIClient
|
||||
|
||||
|
||||
@runner.runner_class('coze-api')
|
||||
class CozeAPIRunner(runner.RequestRunner):
|
||||
"""Coze API 对话请求器"""
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.pipeline_config = pipeline_config
|
||||
self.ap = ap
|
||||
self.agent_token = pipeline_config['ai']['coze-api']['api-key']
|
||||
self.bot_id = pipeline_config['ai']['coze-api'].get('bot-id')
|
||||
self.chat_timeout = pipeline_config['ai']['coze-api'].get('timeout')
|
||||
self.auto_save_history = pipeline_config['ai']['coze-api'].get('auto_save_history')
|
||||
self.api_base = pipeline_config['ai']['coze-api'].get('api-base')
|
||||
|
||||
self.coze = AsyncCozeAPIClient(self.agent_token, self.api_base)
|
||||
|
||||
def _process_thinking_content(
|
||||
self,
|
||||
content: str,
|
||||
) -> tuple[str, str]:
|
||||
"""处理思维链内容
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
Returns:
|
||||
(处理后的内容, 提取的思维链内容)
|
||||
"""
|
||||
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
|
||||
thinking_content = ''
|
||||
# 从 content 中提取 <think> 标签内容
|
||||
if content and '<think>' in content and '</think>' in content:
|
||||
import re
|
||||
|
||||
think_pattern = r'<think>(.*?)</think>'
|
||||
think_matches = re.findall(think_pattern, content, re.DOTALL)
|
||||
if think_matches:
|
||||
thinking_content = '\n'.join(think_matches)
|
||||
# 移除 content 中的 <think> 标签
|
||||
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
|
||||
|
||||
# 根据 remove_think 参数决定是否保留思维链
|
||||
if remove_think:
|
||||
return content, ''
|
||||
else:
|
||||
# 如果有思维链内容,将其以 <think> 格式添加到 content 开头
|
||||
if thinking_content:
|
||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
||||
return content, thinking_content
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> list[dict]:
|
||||
"""预处理用户消息,转换为Coze消息格式
|
||||
|
||||
Returns:
|
||||
list[dict]: Coze消息列表
|
||||
"""
|
||||
messages = []
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
# 多模态消息处理
|
||||
content_parts = []
|
||||
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
content_parts.append({'type': 'text', 'text': ce.text})
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
file_id = await self._get_file_id(file_bytes)
|
||||
content_parts.append({'type': 'image', 'file_id': file_id})
|
||||
elif ce.type == 'file':
|
||||
# 处理文件,上传到Coze
|
||||
file_id = await self._get_file_id(ce.file)
|
||||
content_parts.append({'type': 'file', 'file_id': file_id})
|
||||
|
||||
# 创建多模态消息
|
||||
if content_parts:
|
||||
messages.append(
|
||||
{
|
||||
'role': 'user',
|
||||
'content': json.dumps(content_parts),
|
||||
'content_type': 'object_string',
|
||||
'meta_data': None,
|
||||
}
|
||||
)
|
||||
|
||||
elif isinstance(query.user_message.content, str):
|
||||
# 纯文本消息
|
||||
messages.append(
|
||||
{'role': 'user', 'content': query.user_message.content, 'content_type': 'text', 'meta_data': None}
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
async def _get_file_id(self, file) -> str:
|
||||
"""上传文件到Coze服务
|
||||
Args:
|
||||
file: 文件
|
||||
Returns:
|
||||
str: 文件ID
|
||||
"""
|
||||
file_id = await self.coze.upload(file=file)
|
||||
return file_id
|
||||
|
||||
async def _chat_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用聊天助手(非流式)
|
||||
|
||||
注意:由于cozepy没有提供非流式API,这里使用流式API并在结束后一次性返回完整内容
|
||||
"""
|
||||
user_id = f'{query.launcher_type.value}_{query.launcher_id}'
|
||||
|
||||
# 预处理用户消息
|
||||
additional_messages = await self._preprocess_user_message(query)
|
||||
|
||||
# 获取会话ID
|
||||
conversation_id = None
|
||||
|
||||
# 收集完整内容
|
||||
full_content = ''
|
||||
full_reasoning = ''
|
||||
|
||||
try:
|
||||
# 调用Coze API流式接口
|
||||
async for chunk in self.coze.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
timeout=self.chat_timeout,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
):
|
||||
self.ap.logger.debug(f'coze-chat-stream: {chunk}')
|
||||
|
||||
event_type = chunk.get('event')
|
||||
data = chunk.get('data', {})
|
||||
# Removed debug print statement to avoid cluttering logs in production
|
||||
|
||||
if event_type == 'conversation.message.delta':
|
||||
# 收集内容
|
||||
if 'content' in data:
|
||||
full_content += data.get('content', '')
|
||||
|
||||
# 收集推理内容(如果有)
|
||||
if 'reasoning_content' in data:
|
||||
full_reasoning += data.get('reasoning_content', '')
|
||||
|
||||
elif event_type.split('.')[-1] == 'done': # 本地部署coze时,结束event不为done
|
||||
# 保存会话ID
|
||||
if 'conversation_id' in data:
|
||||
conversation_id = data.get('conversation_id')
|
||||
|
||||
elif event_type == 'error':
|
||||
# 处理错误
|
||||
error_msg = f'Coze API错误: {data.get("message", "未知错误")}'
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=error_msg,
|
||||
)
|
||||
return
|
||||
|
||||
# 处理思维链内容
|
||||
content, thinking_content = self._process_thinking_content(full_content)
|
||||
if full_reasoning:
|
||||
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
|
||||
if not remove_think:
|
||||
content = f'<think>\n{full_reasoning}\n</think>\n{content}'.strip()
|
||||
|
||||
# 一次性返回完整内容
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
|
||||
# 保存会话ID
|
||||
if conversation_id and query.session.using_conversation:
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Coze API错误: {str(e)}')
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=f'Coze API调用失败: {str(e)}',
|
||||
)
|
||||
|
||||
async def _chat_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用聊天助手(流式)"""
|
||||
user_id = f'{query.launcher_type.value}_{query.launcher_id}'
|
||||
|
||||
# 预处理用户消息
|
||||
additional_messages = await self._preprocess_user_message(query)
|
||||
|
||||
# 获取会话ID
|
||||
conversation_id = None
|
||||
|
||||
start_reasoning = False
|
||||
stop_reasoning = False
|
||||
message_idx = 1
|
||||
is_final = False
|
||||
full_content = ''
|
||||
remove_think = self.pipeline_config.get('output', {}).get('misc', {}).get('remove-think', False)
|
||||
|
||||
try:
|
||||
# 调用Coze API流式接口
|
||||
async for chunk in self.coze.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
timeout=self.chat_timeout,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
):
|
||||
self.ap.logger.debug(f'coze-chat-stream-chunk: {chunk}')
|
||||
|
||||
event_type = chunk.get('event')
|
||||
data = chunk.get('data', {})
|
||||
content = ''
|
||||
|
||||
if event_type == 'conversation.message.delta':
|
||||
message_idx += 1
|
||||
# 处理内容增量
|
||||
if 'reasoning_content' in data and not remove_think:
|
||||
reasoning_content = data.get('reasoning_content', '')
|
||||
if reasoning_content and not start_reasoning:
|
||||
content = '<think/>\n'
|
||||
start_reasoning = True
|
||||
content += reasoning_content
|
||||
|
||||
if 'content' in data:
|
||||
if data.get('content', ''):
|
||||
content += data.get('content', '')
|
||||
if not stop_reasoning and start_reasoning:
|
||||
content = f'</think>\n{content}'
|
||||
stop_reasoning = True
|
||||
|
||||
elif event_type.split('.')[-1] == 'done': # 本地部署coze时,结束event不为done
|
||||
# 保存会话ID
|
||||
if 'conversation_id' in data:
|
||||
conversation_id = data.get('conversation_id')
|
||||
if query.session.using_conversation:
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
is_final = True
|
||||
|
||||
elif event_type == 'error':
|
||||
# 处理错误
|
||||
error_msg = f'Coze API错误: {data.get("message", "未知错误")}'
|
||||
yield provider_message.MessageChunk(role='assistant', content=error_msg, finish_reason='error')
|
||||
return
|
||||
full_content += content
|
||||
if message_idx % 8 == 0 or is_final:
|
||||
if full_content:
|
||||
yield provider_message.MessageChunk(role='assistant', content=full_content, is_final=is_final)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Coze API流式调用错误: {str(e)}')
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant', content=f'Coze API流式调用失败: {str(e)}', finish_reason='error'
|
||||
)
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行"""
|
||||
msg_seq = 0
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
async for msg in self._chat_messages_chunk(query):
|
||||
if isinstance(msg, provider_message.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
else:
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
@@ -1,362 +0,0 @@
|
||||
"""
|
||||
Legacy DashScope (阿里云百炼) API Runner.
|
||||
|
||||
DEPRECATED: This runner has been migrated to the AgentRunner plugin format.
|
||||
Use the official `langbot/dashscope-agent` plugin instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import re
|
||||
|
||||
import dashscope
|
||||
|
||||
from .. import runner
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class DashscopeAPIError(Exception):
|
||||
"""Dashscope API 请求失败"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
@runner.runner_class('dashscope-app-api')
|
||||
class DashScopeAPIRunner(runner.RequestRunner):
|
||||
"阿里云百炼DashsscopeAPI对话请求器"
|
||||
|
||||
# 运行器内部使用的配置
|
||||
app_type: str # 应用类型
|
||||
app_id: str # 应用ID
|
||||
api_key: str # API Key
|
||||
references_quote: (
|
||||
str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
|
||||
)
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
"""初始化"""
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
valid_app_types = ['agent', 'workflow']
|
||||
self.app_type = self.pipeline_config['ai']['dashscope-app-api']['app-type']
|
||||
# 检查配置文件中使用的应用类型是否支持
|
||||
if self.app_type not in valid_app_types:
|
||||
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')
|
||||
|
||||
# 初始化Dashscope 参数配置
|
||||
self.app_id = self.pipeline_config['ai']['dashscope-app-api']['app-id']
|
||||
self.api_key = self.pipeline_config['ai']['dashscope-app-api']['api-key']
|
||||
self.references_quote = self.pipeline_config['ai']['dashscope-app-api']['references_quote']
|
||||
|
||||
def _replace_references(self, text, references_dict):
|
||||
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
|
||||
|
||||
# 匹配 <ref>[index_id]</ref> 形式的字符串
|
||||
pattern = re.compile(r'<ref>\[(.*?)\]</ref>')
|
||||
|
||||
def replacement(match):
|
||||
# 获取引用编号
|
||||
ref_key = match.group(1)
|
||||
if ref_key in references_dict:
|
||||
# 如果有对应的参考资料按照provider.json中的reference_quote返回提示,来自哪个参考资料文件
|
||||
return f'({self.references_quote} {references_dict[ref_key]})'
|
||||
else:
|
||||
# 如果没有对应的参考资料,保留原样
|
||||
return match.group(0)
|
||||
|
||||
# 使用 re.sub() 进行替换
|
||||
return pattern.sub(replacement, text)
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
|
||||
plain_text = ''
|
||||
image_ids = []
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
# 暂时不支持上传图片,保留代码以便后续扩展
|
||||
# elif ce.type == "image_base64":
|
||||
# image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
# file_bytes = base64.b64decode(image_b64)
|
||||
# file = ("img.png", file_bytes, f"image/{image_format}")
|
||||
# file_upload_resp = await self.dify_client.upload_file(
|
||||
# file,
|
||||
# f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
# )
|
||||
# image_id = file_upload_resp["id"]
|
||||
# image_ids.append(image_id)
|
||||
elif isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
|
||||
return plain_text, image_ids
|
||||
|
||||
async def _agent_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""Dashscope 智能体对话请求"""
|
||||
|
||||
# 局部变量
|
||||
chunk = None # 流式传输的块
|
||||
pending_content = '' # 待处理的Agent输出内容
|
||||
references_dict = {} # 用于存储引用编号和对应的参考资料
|
||||
plain_text = '' # 用户输入的纯文本信息
|
||||
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
|
||||
|
||||
think_start = False
|
||||
think_end = False
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
has_thoughts = True # 获取思考过程
|
||||
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
|
||||
if remove_think:
|
||||
has_thoughts = False
|
||||
# 发送对话请求
|
||||
response = dashscope.Application.call(
|
||||
api_key=self.api_key, # 智能体应用的API Key
|
||||
app_id=self.app_id, # 智能体应用的ID
|
||||
prompt=plain_text, # 用户输入的文本信息
|
||||
stream=True, # 流式输出
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
enable_thinking=has_thoughts,
|
||||
has_thoughts=has_thoughts,
|
||||
# rag_options={ # 主要用于文件交互,暂不支持
|
||||
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
|
||||
# }
|
||||
)
|
||||
idx_chunk = 0
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
if is_stream:
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
idx_chunk += 1
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
stream_think = stream_output.get('thoughts') or []
|
||||
if stream_think and stream_think[0].get('thought'):
|
||||
if not think_start:
|
||||
think_start = True
|
||||
pending_content += f'<think>\n{stream_think[0].get("thought")}'
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
pending_content += stream_think[0].get('thought')
|
||||
elif think_start and (not stream_think or stream_think[0].get('thought') == '') and not think_end:
|
||||
think_end = True
|
||||
pending_content += '\n</think>\n'
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
# 是否是流式最后一个chunk
|
||||
is_final = False if stream_output.get('finish_reason', False) == 'null' else True
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
|
||||
if idx_chunk % 8 == 0 or is_final:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=is_final,
|
||||
)
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
else:
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
idx_chunk += 1
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
stream_think = stream_output.get('thoughts') or []
|
||||
if stream_think and stream_think[0].get('thought'):
|
||||
if not think_start:
|
||||
think_start = True
|
||||
pending_content += f'<think>\n{stream_think[0].get("thought")}'
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
pending_content += stream_think[0].get('thought')
|
||||
elif think_start and (not stream_think or stream_think[0].get('thought') == '') and not think_end:
|
||||
think_end = True
|
||||
pending_content += '\n</think>\n'
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
async def _workflow_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""Dashscope 工作流对话请求"""
|
||||
|
||||
# 局部变量
|
||||
chunk = None # 流式传输的块
|
||||
pending_content = '' # 待处理的Agent输出内容
|
||||
references_dict = {} # 用于存储引用编号和对应的参考资料
|
||||
plain_text = '' # 用户输入的纯文本信息
|
||||
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
biz_params = {}
|
||||
biz_params.update(query.variables)
|
||||
|
||||
# 发送对话请求
|
||||
response = dashscope.Application.call(
|
||||
api_key=self.api_key, # 智能体应用的API Key
|
||||
app_id=self.app_id, # 智能体应用的ID
|
||||
prompt=plain_text, # 用户输入的文本信息
|
||||
stream=True, # 流式输出
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
biz_params=biz_params, # 工作流应用的自定义输入参数传递
|
||||
flow_stream_mode='message_format', # 消息模式,输出/结束节点的流式结果
|
||||
# rag_options={ # 主要用于文件交互,暂不支持
|
||||
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
|
||||
# }
|
||||
)
|
||||
|
||||
# 处理API返回的流式输出
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
idx_chunk = 0
|
||||
if is_stream:
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
idx_chunk += 1
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
if stream_output.get('workflow_message') is not None:
|
||||
pending_content += stream_output.get('workflow_message').get('message').get('content')
|
||||
# if stream_output.get('text') is not None:
|
||||
# pending_content += stream_output.get('text')
|
||||
|
||||
is_final = False if stream_output.get('finish_reason', False) == 'null' else True
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
if idx_chunk % 8 == 0 or is_final:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
|
||||
else:
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
|
||||
is_final = False if stream_output.get('finish_reason', False) == 'null' else True
|
||||
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行"""
|
||||
msg_seq = 0
|
||||
if self.app_type == 'agent':
|
||||
async for msg in self._agent_messages(query):
|
||||
if isinstance(msg, provider_message.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
elif self.app_type == 'workflow':
|
||||
async for msg in self._workflow_messages(query):
|
||||
if isinstance(msg, provider_message.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
else:
|
||||
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')
|
||||
@@ -1,782 +0,0 @@
|
||||
"""
|
||||
Legacy Dify Service API Runner.
|
||||
|
||||
DEPRECATED: This runner has been migrated to the AgentRunner plugin format.
|
||||
Use the official `langbot/dify-agent` plugin instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import mimetypes
|
||||
|
||||
|
||||
from langbot.pkg.provider import runner
|
||||
from langbot.pkg.core import app
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
from langbot.pkg.utils import image
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from langbot.libs.dify_service_api.v1 import client, errors
|
||||
import httpx
|
||||
|
||||
|
||||
@runner.runner_class('dify-service-api')
|
||||
class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
"""Dify Service API 对话请求器"""
|
||||
|
||||
dify_client: client.AsyncDifyServiceClient
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
valid_app_types = ['chat', 'agent', 'workflow']
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] not in valid_app_types:
|
||||
raise errors.DifyAPIError(
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
|
||||
api_key = self.pipeline_config['ai']['dify-service-api']['api-key']
|
||||
|
||||
self.dify_client = client.AsyncDifyServiceClient(
|
||||
api_key=api_key,
|
||||
base_url=self.pipeline_config['ai']['dify-service-api']['base-url'],
|
||||
)
|
||||
|
||||
def _process_thinking_content(
|
||||
self,
|
||||
content: str,
|
||||
) -> tuple[str, str]:
|
||||
"""处理思维链内容
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
Returns:
|
||||
(处理后的内容, 提取的思维链内容)
|
||||
"""
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
thinking_content = ''
|
||||
# 从 content 中提取 <think> 标签内容
|
||||
if content and '<think>' in content and '</think>' in content:
|
||||
import re
|
||||
|
||||
think_pattern = r'<think>(.*?)</think>'
|
||||
think_matches = re.findall(think_pattern, content, re.DOTALL)
|
||||
if think_matches:
|
||||
thinking_content = '\n'.join(think_matches)
|
||||
# 移除 content 中的 <think> 标签
|
||||
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
|
||||
|
||||
# 3. 根据 remove_think 参数决定是否保留思维链
|
||||
if remove_think:
|
||||
return content, ''
|
||||
else:
|
||||
# 如果有思维链内容,将其以 <think> 格式添加到 content 开头
|
||||
if thinking_content:
|
||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
||||
return content, thinking_content
|
||||
|
||||
def _extract_dify_text_output(self, value: typing.Any) -> str:
|
||||
"""Extract text content from Dify output payload."""
|
||||
if value is None:
|
||||
return ''
|
||||
if isinstance(value, dict):
|
||||
content = value.get('content')
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return ''
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
if isinstance(parsed, dict) and isinstance(parsed.get('content'), str):
|
||||
return parsed['content']
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[dict]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片/文件上传到 Dify 服务
|
||||
|
||||
Returns:
|
||||
tuple[str, list[dict]]: 纯文本和上传后的文件描述(包含 type 与 id)
|
||||
"""
|
||||
plain_text = ''
|
||||
upload_files: list[dict] = []
|
||||
user_tag = f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
|
||||
async def upload_file_bytes(file_name: str, file_bytes: bytes, content_type: str) -> str:
|
||||
file_name = file_name or 'file'
|
||||
content_type = content_type or 'application/octet-stream'
|
||||
file = (file_name, file_bytes, content_type)
|
||||
resp = await self.dify_client.upload_file(file, user_tag)
|
||||
return resp['id']
|
||||
|
||||
async def download_file(file_url: str) -> tuple[bytes, str]:
|
||||
"""Download file from url (supports data url)."""
|
||||
|
||||
async with httpx.AsyncClient() as client_session:
|
||||
resp = await client_session.get(file_url)
|
||||
resp.raise_for_status()
|
||||
content_type = (
|
||||
resp.headers.get('content-type') or mimetypes.guess_type(file_url)[0] or 'application/octet-stream'
|
||||
)
|
||||
return resp.content, content_type
|
||||
|
||||
def _detect_file_type(content_type: str) -> str:
|
||||
"""Map MIME to dify file type."""
|
||||
if content_type and content_type.startswith('image/'):
|
||||
return 'image'
|
||||
if content_type and content_type.startswith('audio/'):
|
||||
return 'audio'
|
||||
if content_type and content_type.startswith('video/'):
|
||||
return 'video'
|
||||
return 'document'
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
image_id = await upload_file_bytes(f'img.{image_format}', file_bytes, f'image/{image_format}')
|
||||
upload_files.append({'type': 'image', 'id': image_id})
|
||||
elif ce.type == 'file_url':
|
||||
file_url = getattr(ce, 'file_url', None)
|
||||
file_name = getattr(ce, 'file_name', None) or 'file'
|
||||
try:
|
||||
file_bytes, content_type = await download_file(file_url)
|
||||
file_id = await upload_file_bytes(file_name, file_bytes, content_type)
|
||||
file_type = _detect_file_type(content_type)
|
||||
upload_files.append({'type': file_type, 'id': file_id})
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'dify file upload failed: {e}')
|
||||
elif ce.type == 'file_base64':
|
||||
file_name = getattr(ce, 'file_name', None) or 'file'
|
||||
|
||||
header, b64_data = ce.file_base64.split(',', 1)
|
||||
content_type = 'application/octet-stream'
|
||||
if ';' in header:
|
||||
content_type = header.split(';')[0][5:] or content_type
|
||||
file_bytes = base64.b64decode(b64_data)
|
||||
file_id = await upload_file_bytes(file_name, file_bytes, content_type)
|
||||
file_type = _detect_file_type(content_type)
|
||||
upload_files.append({'type': file_type, 'id': file_id})
|
||||
|
||||
elif isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
|
||||
plain_text = plain_text if plain_text else self.pipeline_config['ai']['dify-service-api']['base-prompt']
|
||||
|
||||
return plain_text, upload_files
|
||||
|
||||
async def _chat_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or None
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
mode = 'basic' # 标记是基础编排还是工作流编排
|
||||
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
inputs = {}
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
chunk = None # 初始化chunk变量,防止在没有响应时引用错误
|
||||
|
||||
async for chunk in self.dify_client.chat_messages(
|
||||
inputs=inputs,
|
||||
query=plain_text,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] == 'workflow_started':
|
||||
mode = 'workflow'
|
||||
|
||||
if mode == 'workflow':
|
||||
if chunk['event'] == 'node_finished':
|
||||
if chunk['data']['node_type'] == 'answer':
|
||||
answer = self._extract_dify_text_output(chunk['data']['outputs'].get('answer'))
|
||||
content, _ = self._process_thinking_content(answer)
|
||||
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
elif mode == 'basic':
|
||||
if chunk['event'] == 'message':
|
||||
basic_mode_pending_chunk += chunk['answer']
|
||||
elif chunk['event'] == 'message_end':
|
||||
content, _ = self._process_thinking_content(basic_mode_pending_chunk)
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _agent_chat_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or None
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
ignored_events = []
|
||||
|
||||
inputs = {}
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
pending_agent_message = ''
|
||||
|
||||
chunk = None # 初始化chunk变量,防止在没有响应时引用错误
|
||||
|
||||
async for chunk in self.dify_client.chat_messages(
|
||||
inputs=inputs,
|
||||
query=plain_text,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
response_mode='streaming',
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-agent-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'agent_message' or chunk['event'] == 'message':
|
||||
pending_agent_message += chunk['answer']
|
||||
else:
|
||||
if pending_agent_message.strip() != '':
|
||||
pending_agent_message = pending_agent_message.replace('</details>Action:', '</details>')
|
||||
content, _ = self._process_thinking_content(pending_agent_message)
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
pending_agent_message = ''
|
||||
|
||||
if chunk['event'] == 'agent_thought':
|
||||
if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
|
||||
continue
|
||||
|
||||
if chunk['tool']:
|
||||
msg = provider_message.Message(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id=chunk['id'],
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name=chunk['tool'],
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
yield msg
|
||||
if chunk['event'] == 'message_file':
|
||||
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
|
||||
# 检查URL是否已经是完整的连接
|
||||
if chunk['url'].startswith('http://') or chunk['url'].startswith('https://'):
|
||||
image_url = chunk['url']
|
||||
else:
|
||||
base_url = self.dify_client.base_url
|
||||
|
||||
if base_url.endswith('/v1'):
|
||||
base_url = base_url[:-3]
|
||||
|
||||
image_url = base_url + chunk['url']
|
||||
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=[provider_message.ContentElement.from_image_url(image_url)],
|
||||
)
|
||||
if chunk['event'] == 'error':
|
||||
raise errors.DifyAPIError('dify 服务错误: ' + chunk['message'])
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _workflow_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用工作流"""
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
query.variables['conversation_id'] = query.session.using_conversation.uuid
|
||||
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
ignored_events = ['text_chunk', 'workflow_started']
|
||||
|
||||
inputs = { # these variables are legacy variables, we need to keep them for compatibility
|
||||
'langbot_user_message_text': plain_text,
|
||||
'langbot_session_id': query.variables['session_id'],
|
||||
'langbot_conversation_id': query.variables['conversation_id'],
|
||||
'langbot_msg_create_time': query.variables['msg_create_time'],
|
||||
}
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
async for chunk in self.dify_client.workflow_run(
|
||||
inputs=inputs,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'node_started':
|
||||
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
|
||||
continue
|
||||
|
||||
msg = provider_message.Message(
|
||||
role='assistant',
|
||||
content=None,
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id=chunk['data']['node_id'],
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name=chunk['data']['title'],
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
elif chunk['event'] == 'workflow_finished':
|
||||
if chunk['data']['error']:
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
|
||||
|
||||
msg = provider_message.Message(
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
async def _chat_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or None
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
mode = 'basic'
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
inputs = {}
|
||||
|
||||
inputs.update(query.variables)
|
||||
message_idx = 0
|
||||
|
||||
chunk = None # 初始化chunk变量,防止在没有响应时引用错误
|
||||
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
yielded_final = False
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
|
||||
async for chunk in self.dify_client.chat_messages(
|
||||
inputs=inputs,
|
||||
query=plain_text,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] == 'workflow_started':
|
||||
mode = 'workflow'
|
||||
elif chunk['event'] in ('node_started', 'node_finished', 'workflow_finished'):
|
||||
# Some Dify deployments may omit workflow_started in streamed chunks.
|
||||
mode = 'workflow'
|
||||
|
||||
if chunk['event'] == 'message':
|
||||
message_idx += 1
|
||||
if remove_think:
|
||||
if '<think>' in chunk['answer'] and not think_start:
|
||||
think_start = True
|
||||
continue
|
||||
if '</think>' in chunk['answer'] and not think_end:
|
||||
import re
|
||||
|
||||
content = re.sub(r'^\n</think>', '', chunk['answer'])
|
||||
basic_mode_pending_chunk += content
|
||||
think_end = True
|
||||
elif think_end:
|
||||
basic_mode_pending_chunk += chunk['answer']
|
||||
if think_start:
|
||||
continue
|
||||
|
||||
else:
|
||||
basic_mode_pending_chunk += chunk['answer']
|
||||
|
||||
if chunk['event'] == 'message_end':
|
||||
is_final = True
|
||||
elif chunk['event'] == 'workflow_finished':
|
||||
is_final = True
|
||||
if chunk['data'].get('error'):
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
|
||||
if mode == 'workflow' and chunk['event'] == 'node_finished':
|
||||
if chunk['data'].get('node_type') == 'answer':
|
||||
answer = self._extract_dify_text_output(chunk['data'].get('outputs', {}).get('answer'))
|
||||
if answer:
|
||||
basic_mode_pending_chunk = answer
|
||||
|
||||
if (
|
||||
not yielded_final
|
||||
and (is_final or message_idx % 8 == 0)
|
||||
and (basic_mode_pending_chunk != '' or is_final)
|
||||
):
|
||||
# content, _ = self._process_thinking_content(basic_mode_pending_chunk)
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=basic_mode_pending_chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
if is_final:
|
||||
yielded_final = True
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _agent_chat_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or None
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
ignored_events = []
|
||||
|
||||
inputs = {}
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
pending_agent_message = ''
|
||||
|
||||
chunk = None # 初始化chunk变量,防止在没有响应时引用错误
|
||||
message_idx = 0
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
|
||||
async for chunk in self.dify_client.chat_messages(
|
||||
inputs=inputs,
|
||||
query=plain_text,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
response_mode='streaming',
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-agent-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'agent_message':
|
||||
message_idx += 1
|
||||
if remove_think:
|
||||
if '<think>' in chunk['answer'] and not think_start:
|
||||
think_start = True
|
||||
continue
|
||||
if '</think>' in chunk['answer'] and not think_end:
|
||||
import re
|
||||
|
||||
content = re.sub(r'^\n</think>', '', chunk['answer'])
|
||||
pending_agent_message += content
|
||||
think_end = True
|
||||
elif think_end or not think_start:
|
||||
pending_agent_message += chunk['answer']
|
||||
if think_start and not think_end:
|
||||
continue
|
||||
|
||||
else:
|
||||
pending_agent_message += chunk['answer']
|
||||
elif chunk['event'] == 'message_end':
|
||||
is_final = True
|
||||
else:
|
||||
if chunk['event'] == 'agent_thought':
|
||||
if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
|
||||
continue
|
||||
message_idx += 1
|
||||
if chunk['tool']:
|
||||
msg = provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id=chunk['id'],
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name=chunk['tool'],
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
yield msg
|
||||
if chunk['event'] == 'message_file':
|
||||
message_idx += 1
|
||||
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
|
||||
# 检查URL是否已经是完整的连接
|
||||
if chunk['url'].startswith('http://') or chunk['url'].startswith('https://'):
|
||||
image_url = chunk['url']
|
||||
else:
|
||||
base_url = self.dify_client.base_url
|
||||
|
||||
if base_url.endswith('/v1'):
|
||||
base_url = base_url[:-3]
|
||||
|
||||
image_url = base_url + chunk['url']
|
||||
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=[provider_message.ContentElement.from_image_url(image_url)],
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
if chunk['event'] == 'error':
|
||||
raise errors.DifyAPIError('dify 服务错误: ' + chunk['message'])
|
||||
if message_idx % 8 == 0 or is_final:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_agent_message,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _workflow_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用工作流"""
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
query.variables['conversation_id'] = query.session.using_conversation.uuid
|
||||
|
||||
plain_text, upload_files = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': f['type'],
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': f['id'],
|
||||
}
|
||||
for f in upload_files
|
||||
]
|
||||
|
||||
ignored_events = ['workflow_started']
|
||||
|
||||
inputs = { # these variables are legacy variables, we need to keep them for compatibility
|
||||
'langbot_user_message_text': plain_text,
|
||||
'langbot_session_id': query.variables['session_id'],
|
||||
'langbot_conversation_id': query.variables['conversation_id'],
|
||||
'langbot_msg_create_time': query.variables['msg_create_time'],
|
||||
}
|
||||
|
||||
inputs.update(query.variables)
|
||||
messsage_idx = 0
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
workflow_contents = ''
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
async for chunk in self.dify_client.workflow_run(
|
||||
inputs=inputs,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
if chunk['event'] == 'workflow_finished':
|
||||
is_final = True
|
||||
if chunk['data']['error']:
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
|
||||
if chunk['event'] == 'text_chunk':
|
||||
messsage_idx += 1
|
||||
if remove_think:
|
||||
if '<think>' in chunk['data']['text'] and not think_start:
|
||||
think_start = True
|
||||
continue
|
||||
if '</think>' in chunk['data']['text'] and not think_end:
|
||||
import re
|
||||
|
||||
content = re.sub(r'^\n</think>', '', chunk['data']['text'])
|
||||
workflow_contents += content
|
||||
think_end = True
|
||||
elif think_end:
|
||||
workflow_contents += chunk['data']['text']
|
||||
if think_start:
|
||||
continue
|
||||
|
||||
else:
|
||||
workflow_contents += chunk['data']['text']
|
||||
|
||||
if chunk['event'] == 'node_started':
|
||||
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
|
||||
continue
|
||||
messsage_idx += 1
|
||||
msg = provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=None,
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id=chunk['data']['node_id'],
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name=chunk['data']['title'],
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
if messsage_idx % 8 == 0 or is_final:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=workflow_contents,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行请求"""
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
msg_idx = 0
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
||||
async for msg in self._chat_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent':
|
||||
async for msg in self._agent_chat_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow':
|
||||
async for msg in self._workflow_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
else:
|
||||
raise errors.DifyAPIError(
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
else:
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent':
|
||||
async for msg in self._agent_chat_messages(query):
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow':
|
||||
async for msg in self._workflow_messages(query):
|
||||
yield msg
|
||||
else:
|
||||
raise errors.DifyAPIError(
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
@@ -1,187 +0,0 @@
|
||||
"""
|
||||
Legacy Langflow API Runner.
|
||||
|
||||
DEPRECATED: This runner has been migrated to the AgentRunner plugin format.
|
||||
Use the official `langbot/langflow-agent` plugin instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import httpx
|
||||
import uuid
|
||||
import traceback
|
||||
|
||||
from .. import runner
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
@runner.runner_class('langflow-api')
|
||||
class LangflowAPIRunner(runner.RequestRunner):
|
||||
"""Langflow API 对话请求器"""
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
async def _build_request_payload(self, query: pipeline_query.Query) -> dict:
|
||||
"""构建请求负载
|
||||
|
||||
Args:
|
||||
query: 用户查询对象
|
||||
|
||||
Returns:
|
||||
dict: 请求负载
|
||||
"""
|
||||
# 获取用户消息文本
|
||||
user_message_text = ''
|
||||
if isinstance(query.user_message.content, str):
|
||||
user_message_text = query.user_message.content
|
||||
elif isinstance(query.user_message.content, list):
|
||||
for item in query.user_message.content:
|
||||
if item.type == 'text':
|
||||
user_message_text += item.text
|
||||
|
||||
# 从配置中获取 input_type 和 output_type,如果未配置则使用默认值
|
||||
input_type = self.pipeline_config['ai']['langflow-api'].get('input_type', 'chat')
|
||||
output_type = self.pipeline_config['ai']['langflow-api'].get('output_type', 'chat')
|
||||
|
||||
# 构建基本负载
|
||||
payload = {
|
||||
'output_type': output_type,
|
||||
'input_type': input_type,
|
||||
'input_value': user_message_text,
|
||||
'session_id': str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
# 如果配置中有tweaks,则添加到负载中
|
||||
tweaks = json.loads(self.pipeline_config['ai']['langflow-api'].get('tweaks'))
|
||||
if tweaks:
|
||||
payload['tweaks'] = tweaks
|
||||
|
||||
return payload
|
||||
|
||||
async def run(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message | provider_message.MessageChunk, None]:
|
||||
"""运行请求
|
||||
|
||||
Args:
|
||||
query: 用户查询对象
|
||||
|
||||
Yields:
|
||||
Message: 回复消息
|
||||
"""
|
||||
# 检查是否支持流式输出
|
||||
is_stream = False
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
# 从配置中获取API参数
|
||||
base_url = self.pipeline_config['ai']['langflow-api']['base-url']
|
||||
api_key = self.pipeline_config['ai']['langflow-api']['api-key']
|
||||
flow_id = self.pipeline_config['ai']['langflow-api']['flow-id']
|
||||
|
||||
# 构建API URL
|
||||
url = f'{base_url.rstrip("/")}/api/v1/run/{flow_id}'
|
||||
|
||||
# 构建请求负载
|
||||
payload = await self._build_request_payload(query)
|
||||
|
||||
# 设置请求头
|
||||
headers = {'Content-Type': 'application/json', 'x-api-key': api_key}
|
||||
|
||||
# 发送请求
|
||||
async with httpx.AsyncClient() as client:
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with client.stream('POST', url, json=payload, headers=headers, timeout=120.0) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
accumulated_content = ''
|
||||
message_count = 0
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
data_str = line
|
||||
|
||||
if data_str.startswith('data: '):
|
||||
data_str = data_str[6:] # 移除 "data: " 前缀
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
|
||||
# 提取消息内容
|
||||
message_text = ''
|
||||
if 'outputs' in data and len(data['outputs']) > 0:
|
||||
output = data['outputs'][0]
|
||||
if 'outputs' in output and len(output['outputs']) > 0:
|
||||
inner_output = output['outputs'][0]
|
||||
if 'outputs' in inner_output and 'message' in inner_output['outputs']:
|
||||
message_data = inner_output['outputs']['message']
|
||||
if 'message' in message_data:
|
||||
message_text = message_data['message']
|
||||
|
||||
# 如果没有找到消息,尝试其他可能的路径
|
||||
if not message_text and 'messages' in data:
|
||||
messages = data['messages']
|
||||
if messages and len(messages) > 0:
|
||||
message_text = messages[0].get('message', '')
|
||||
|
||||
if message_text:
|
||||
# 更新累积内容
|
||||
accumulated_content = message_text
|
||||
message_count += 1
|
||||
|
||||
# 每8条消息或有新内容时生成一个chunk
|
||||
if message_count % 8 == 0 or len(message_text) > 0:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant', content=accumulated_content, is_final=False
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON,跳过这一行
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# 发送最终消息
|
||||
yield provider_message.MessageChunk(role='assistant', content=accumulated_content, is_final=True)
|
||||
else:
|
||||
# 非流式请求
|
||||
response = await client.post(url, json=payload, headers=headers, timeout=120.0)
|
||||
response.raise_for_status()
|
||||
|
||||
# 解析响应
|
||||
response_data = response.json()
|
||||
|
||||
# 提取消息内容
|
||||
# 根据Langflow API文档,响应结构可能在outputs[0].outputs[0].outputs.message.message中
|
||||
message_text = ''
|
||||
if 'outputs' in response_data and len(response_data['outputs']) > 0:
|
||||
output = response_data['outputs'][0]
|
||||
if 'outputs' in output and len(output['outputs']) > 0:
|
||||
inner_output = output['outputs'][0]
|
||||
if 'outputs' in inner_output and 'message' in inner_output['outputs']:
|
||||
message_data = inner_output['outputs']['message']
|
||||
if 'message' in message_data:
|
||||
message_text = message_data['message']
|
||||
|
||||
# 如果没有找到消息,尝试其他可能的路径
|
||||
if not message_text and 'messages' in response_data:
|
||||
messages = response_data['messages']
|
||||
if messages and len(messages) > 0:
|
||||
message_text = messages[0].get('message', '')
|
||||
|
||||
# 如果仍然没有找到消息,返回完整响应的字符串表示
|
||||
if not message_text:
|
||||
message_text = json.dumps(response_data, ensure_ascii=False, indent=2)
|
||||
|
||||
# 生成回复消息
|
||||
if is_stream:
|
||||
yield provider_message.MessageChunk(role='assistant', content=message_text, is_final=True)
|
||||
else:
|
||||
reply_message = provider_message.Message(role='assistant', content=message_text)
|
||||
yield reply_message
|
||||
@@ -1,520 +0,0 @@
|
||||
"""
|
||||
Legacy Local Agent Runner.
|
||||
|
||||
DEPRECATED: This runner has been migrated to the AgentRunner plugin format.
|
||||
Use the official `langbot/local-agent` plugin instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import copy
|
||||
import typing
|
||||
from .. import runner
|
||||
from ..modelmgr import requester as modelmgr_requester
|
||||
from ..tools.loaders.native import EXEC_TOOL_NAME
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.rag.context as rag_context
|
||||
|
||||
|
||||
rag_combined_prompt_template = """
|
||||
The following are relevant context entries retrieved from the knowledge base.
|
||||
Please use them to answer the user's message.
|
||||
Respond in the same language as the user's input.
|
||||
|
||||
<context>
|
||||
{rag_context}
|
||||
</context>
|
||||
|
||||
<user_message>
|
||||
{user_message}
|
||||
</user_message>
|
||||
"""
|
||||
|
||||
SANDBOX_EXEC_TOOL_NAME = 'sandbox_exec'
|
||||
SANDBOX_EXEC_SYSTEM_GUIDANCE = (
|
||||
'When sandbox_exec is available, use it for exact calculations, statistics, structured data parsing, '
|
||||
'and code execution instead of estimating mentally. If the user provides numbers, tables, CSV-like text, '
|
||||
'JSON, or other data and asks for a computed answer, prefer running a short Python script in sandbox_exec '
|
||||
'and then answer from the tool result.'
|
||||
)
|
||||
|
||||
|
||||
# Hard cap on tool-call rounds within a single agent turn. A looping or
|
||||
# adversarial model can otherwise emit tool calls indefinitely (each potentially
|
||||
# a sandbox exec), yielding a non-terminating request and runaway cost. Set
|
||||
# generously so it never interrupts legitimate multi-step agentic workflows.
|
||||
MAX_TOOL_CALL_ROUNDS = 128
|
||||
|
||||
|
||||
@runner.runner_class('local-agent')
|
||||
class LocalAgentRunner(runner.RequestRunner):
|
||||
"""Local agent request runner"""
|
||||
|
||||
def _build_request_messages(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
user_message: provider_message.Message,
|
||||
) -> list[provider_message.Message]:
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy()
|
||||
|
||||
if any(getattr(tool, 'name', None) == EXEC_TOOL_NAME for tool in query.use_funcs or []):
|
||||
req_messages.append(
|
||||
provider_message.Message(
|
||||
role='system',
|
||||
content=self.ap.box_service.get_system_guidance(),
|
||||
)
|
||||
)
|
||||
|
||||
req_messages.append(user_message)
|
||||
return req_messages
|
||||
|
||||
async def _get_model_candidates(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
) -> list[modelmgr_requester.RuntimeLLMModel]:
|
||||
"""Build ordered list of models to try: primary model + fallback models."""
|
||||
candidates = []
|
||||
|
||||
# Primary model
|
||||
if query.use_llm_model_uuid:
|
||||
try:
|
||||
primary = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
|
||||
candidates.append(primary)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(f'Primary model {query.use_llm_model_uuid} not found')
|
||||
|
||||
# Fallback models
|
||||
fallback_uuids = (query.variables or {}).get('_fallback_model_uuids', [])
|
||||
for fb_uuid in fallback_uuids:
|
||||
try:
|
||||
fb_model = await self.ap.model_mgr.get_model_by_uuid(fb_uuid)
|
||||
candidates.append(fb_model)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(f'Fallback model {fb_uuid} not found, skipping')
|
||||
|
||||
return candidates
|
||||
|
||||
async def _invoke_with_fallback(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
candidates: list[modelmgr_requester.RuntimeLLMModel],
|
||||
messages: list,
|
||||
funcs: list,
|
||||
remove_think: bool,
|
||||
) -> tuple[provider_message.Message, modelmgr_requester.RuntimeLLMModel]:
|
||||
"""Try non-streaming invocation with sequential fallback. Returns (message, model_used)."""
|
||||
last_error = None
|
||||
for model in candidates:
|
||||
try:
|
||||
msg = await model.provider.invoke_llm(
|
||||
query,
|
||||
model,
|
||||
messages,
|
||||
funcs if model.model_entity.abilities.__contains__('func_call') else [],
|
||||
extra_args=model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
return msg, model
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.ap.logger.warning(f'Model {model.model_entity.name} failed: {e}, trying next fallback...')
|
||||
raise last_error or RuntimeError('No model candidates available')
|
||||
|
||||
async def _invoke_stream_with_fallback(
|
||||
self,
|
||||
query: pipeline_query.Query,
|
||||
candidates: list[modelmgr_requester.RuntimeLLMModel],
|
||||
messages: list,
|
||||
funcs: list,
|
||||
remove_think: bool,
|
||||
) -> tuple[typing.AsyncGenerator, modelmgr_requester.RuntimeLLMModel]:
|
||||
"""Try streaming invocation with sequential fallback. Returns (stream_generator, model_used).
|
||||
|
||||
Fallback is only possible before any chunks have been yielded to the client.
|
||||
Once streaming starts, the model is committed.
|
||||
"""
|
||||
last_error = None
|
||||
for model in candidates:
|
||||
try:
|
||||
stream = model.provider.invoke_llm_stream(
|
||||
query,
|
||||
model,
|
||||
messages,
|
||||
funcs if model.model_entity.abilities.__contains__('func_call') else [],
|
||||
extra_args=model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
# Attempt to get the first chunk to verify the stream works
|
||||
first_chunk = await stream.__anext__()
|
||||
|
||||
async def _chain_stream(first, rest):
|
||||
yield first
|
||||
async for chunk in rest:
|
||||
yield chunk
|
||||
|
||||
return _chain_stream(first_chunk, stream), model
|
||||
except StopAsyncIteration:
|
||||
# Empty stream — treat as success (model returned nothing)
|
||||
async def _empty_stream():
|
||||
return
|
||||
yield # make it a generator
|
||||
|
||||
return _empty_stream(), model
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.ap.logger.warning(f'Model {model.model_entity.name} stream failed: {e}, trying next fallback...')
|
||||
raise last_error or RuntimeError('No model candidates available')
|
||||
|
||||
async def run(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message | provider_message.MessageChunk, None]:
|
||||
"""Run request"""
|
||||
pending_tool_calls = []
|
||||
initial_response_emitted = False
|
||||
|
||||
# Get knowledge bases list from query variables (set by PreProcessor,
|
||||
# may have been modified by plugins during PromptPreProcessing)
|
||||
kb_uuids = query.variables.get('_knowledge_base_uuids', [])
|
||||
|
||||
user_message = copy.deepcopy(query.user_message)
|
||||
|
||||
user_message_text = ''
|
||||
|
||||
if isinstance(user_message.content, str):
|
||||
user_message_text = user_message.content
|
||||
elif isinstance(user_message.content, list):
|
||||
for ce in user_message.content:
|
||||
if ce.type == 'text':
|
||||
user_message_text += ce.text
|
||||
break
|
||||
|
||||
if kb_uuids and user_message_text:
|
||||
# only support text for now
|
||||
all_results: list[rag_context.RetrievalResultEntry] = []
|
||||
|
||||
# Retrieve from each knowledge base
|
||||
for kb_uuid in kb_uuids:
|
||||
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
|
||||
if not kb:
|
||||
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found, skipping')
|
||||
continue
|
||||
|
||||
result = await kb.retrieve(
|
||||
user_message_text,
|
||||
settings={
|
||||
'bot_uuid': query.bot_uuid or '',
|
||||
'sender_id': str(query.sender_id),
|
||||
'session_name': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
},
|
||||
)
|
||||
|
||||
if result:
|
||||
all_results.extend(result)
|
||||
|
||||
# Rerank step: re-score results using a rerank model if configured
|
||||
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||
rerank_model_uuid = local_agent_config.get('rerank-model', '')
|
||||
if rerank_model_uuid == '__none__':
|
||||
rerank_model_uuid = ''
|
||||
self.ap.logger.info(
|
||||
f'Rerank config: model_uuid={rerank_model_uuid!r}, '
|
||||
f'results={len(all_results)}, '
|
||||
f'local_agent_keys={list(local_agent_config.keys())}'
|
||||
)
|
||||
if all_results and rerank_model_uuid:
|
||||
try:
|
||||
rerank_model = await self.ap.model_mgr.get_rerank_model_by_uuid(rerank_model_uuid)
|
||||
rerank_top_k = int(local_agent_config.get('rerank-top-k', 5))
|
||||
|
||||
doc_texts = []
|
||||
for entry in all_results:
|
||||
text = ' '.join(c.text for c in entry.content if c.type == 'text' and c.text)
|
||||
doc_texts.append(text)
|
||||
|
||||
doc_texts_capped = doc_texts[:64]
|
||||
scores = await rerank_model.provider.invoke_rerank(
|
||||
model=rerank_model,
|
||||
query=user_message_text,
|
||||
documents=doc_texts_capped,
|
||||
)
|
||||
|
||||
scored = sorted(scores, key=lambda x: x.get('relevance_score', 0), reverse=True)
|
||||
top_indices = [s['index'] for s in scored[:rerank_top_k] if s['index'] < len(all_results)]
|
||||
all_results = [all_results[i] for i in top_indices]
|
||||
|
||||
self.ap.logger.info(
|
||||
f'Rerank complete: {len(doc_texts)} docs reranked -> top {len(all_results)} kept (top_k={rerank_top_k})'
|
||||
)
|
||||
except ValueError:
|
||||
self.ap.logger.warning(f'Rerank model {rerank_model_uuid} not found, skipping rerank')
|
||||
except Exception as e:
|
||||
self.ap.logger.warning(f'Rerank failed, using original order: {e}')
|
||||
|
||||
final_user_message_text = ''
|
||||
|
||||
if all_results:
|
||||
texts = []
|
||||
idx = 1
|
||||
for entry in all_results:
|
||||
for content in entry.content:
|
||||
if content.type == 'text' and content.text is not None:
|
||||
texts.append(f'[{idx}] {content.text}')
|
||||
idx += 1
|
||||
rag_context_text = '\n\n'.join(texts)
|
||||
final_user_message_text = rag_combined_prompt_template.format(
|
||||
rag_context=rag_context_text, user_message=user_message_text
|
||||
)
|
||||
|
||||
else:
|
||||
final_user_message_text = user_message_text
|
||||
|
||||
self.ap.logger.debug(f'Final user message text: {final_user_message_text}')
|
||||
|
||||
for ce in user_message.content:
|
||||
if ce.type == 'text':
|
||||
ce.text = final_user_message_text
|
||||
break
|
||||
|
||||
req_messages = self._build_request_messages(query, user_message)
|
||||
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
remove_think = query.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
|
||||
# Build ordered candidate list (primary + fallbacks)
|
||||
candidates = await self._get_model_candidates(query)
|
||||
if not candidates:
|
||||
raise RuntimeError('No LLM model configured for local-agent runner')
|
||||
|
||||
self.ap.logger.debug(
|
||||
f'localagent req: query={query.query_id} req_messages={req_messages} '
|
||||
f'candidates={[m.model_entity.name for m in candidates]}'
|
||||
)
|
||||
|
||||
if not is_stream:
|
||||
# Non-streaming: invoke with fallback
|
||||
msg, use_llm_model = await self._invoke_with_fallback(
|
||||
query,
|
||||
candidates,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
remove_think,
|
||||
)
|
||||
final_msg = msg
|
||||
else:
|
||||
# Streaming: invoke with fallback
|
||||
tool_calls_map: dict[str, provider_message.ToolCall] = {}
|
||||
msg_idx = 0
|
||||
accumulated_content = ''
|
||||
last_role = 'assistant'
|
||||
msg_sequence = 1
|
||||
|
||||
stream_src, use_llm_model = await self._invoke_stream_with_fallback(
|
||||
query,
|
||||
candidates,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
remove_think,
|
||||
)
|
||||
async for msg in stream_src:
|
||||
msg_idx = msg_idx + 1
|
||||
|
||||
if msg.role:
|
||||
last_role = msg.role
|
||||
|
||||
if msg.content:
|
||||
accumulated_content += msg.content
|
||||
|
||||
if msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
if tool_call.id not in tool_calls_map:
|
||||
tool_calls_map[tool_call.id] = provider_message.ToolCall(
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=provider_message.FunctionCall(
|
||||
name=tool_call.function.name if tool_call.function else '', arguments=''
|
||||
),
|
||||
)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||
|
||||
if msg_idx % 8 == 0 or msg.is_final:
|
||||
msg_sequence += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||
is_final=msg.is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
initial_response_emitted = True
|
||||
|
||||
final_msg = provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if tool_calls_map else None,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
|
||||
pending_tool_calls = final_msg.tool_calls
|
||||
first_content = final_msg.content
|
||||
if isinstance(final_msg, provider_message.MessageChunk):
|
||||
first_end_sequence = final_msg.msg_sequence
|
||||
|
||||
if not is_stream:
|
||||
yield final_msg
|
||||
elif not initial_response_emitted:
|
||||
yield final_msg
|
||||
initial_response_emitted = True
|
||||
|
||||
req_messages.append(final_msg)
|
||||
|
||||
# Once a model succeeds, commit to it for the tool call loop
|
||||
# (no fallback mid-conversation — different models may interpret tool results differently)
|
||||
tool_call_round = 0
|
||||
while pending_tool_calls:
|
||||
tool_call_round += 1
|
||||
if tool_call_round > MAX_TOOL_CALL_ROUNDS:
|
||||
self.ap.logger.warning(
|
||||
f'Tool-call loop reached the {MAX_TOOL_CALL_ROUNDS}-round cap '
|
||||
f'(query_id={query.query_id}); stopping to avoid a non-terminating request.'
|
||||
)
|
||||
break
|
||||
for tool_call in pending_tool_calls:
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
if func.arguments:
|
||||
parameters = json.loads(func.arguments)
|
||||
else:
|
||||
parameters = {}
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query)
|
||||
|
||||
# Handle return value content
|
||||
tool_content = None
|
||||
if (
|
||||
isinstance(func_ret, list)
|
||||
and len(func_ret) > 0
|
||||
and isinstance(func_ret[0], provider_message.ContentElement)
|
||||
):
|
||||
tool_content = func_ret
|
||||
else:
|
||||
tool_content = json.dumps(func_ret, ensure_ascii=False)
|
||||
|
||||
if is_stream:
|
||||
msg = provider_message.MessageChunk(
|
||||
role='tool',
|
||||
content=tool_content,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
else:
|
||||
msg = provider_message.Message(
|
||||
role='tool',
|
||||
content=tool_content,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
req_messages.append(msg)
|
||||
except Exception as e:
|
||||
if is_stream:
|
||||
err_msg = provider_message.MessageChunk(
|
||||
role='tool',
|
||||
content=f'err: {e}',
|
||||
tool_call_id=tool_call.id,
|
||||
is_final=True,
|
||||
)
|
||||
else:
|
||||
err_msg = provider_message.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id)
|
||||
|
||||
yield err_msg
|
||||
|
||||
req_messages.append(err_msg)
|
||||
|
||||
self.ap.logger.debug(
|
||||
f'localagent req: query={query.query_id} req_messages={req_messages} '
|
||||
f'use_llm_model={use_llm_model.model_entity.name}'
|
||||
)
|
||||
|
||||
if is_stream:
|
||||
tool_calls_map = {}
|
||||
msg_idx = 0
|
||||
accumulated_content = ''
|
||||
last_role = 'assistant'
|
||||
msg_sequence = first_end_sequence
|
||||
|
||||
tool_stream_src = use_llm_model.provider.invoke_llm_stream(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
|
||||
extra_args=use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
async for msg in tool_stream_src:
|
||||
msg_idx += 1
|
||||
|
||||
if msg.role:
|
||||
last_role = msg.role
|
||||
|
||||
# Prepend first-round content on first chunk of tool-call round
|
||||
if msg_idx == 1:
|
||||
accumulated_content = first_content if first_content is not None else accumulated_content
|
||||
|
||||
if msg.content:
|
||||
accumulated_content += msg.content
|
||||
|
||||
if msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
if tool_call.id not in tool_calls_map:
|
||||
tool_calls_map[tool_call.id] = provider_message.ToolCall(
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=provider_message.FunctionCall(
|
||||
name=tool_call.function.name if tool_call.function else '', arguments=''
|
||||
),
|
||||
)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||
|
||||
if msg_idx % 8 == 0 or msg.is_final:
|
||||
msg_sequence += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||
is_final=msg.is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
|
||||
final_msg = provider_message.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if tool_calls_map else None,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
else:
|
||||
# Non-streaming: use committed model directly (no fallback in tool loop)
|
||||
msg = await use_llm_model.provider.invoke_llm(
|
||||
query,
|
||||
use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs if use_llm_model.model_entity.abilities.__contains__('func_call') else [],
|
||||
extra_args=use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
|
||||
yield msg
|
||||
final_msg = msg
|
||||
|
||||
pending_tool_calls = final_msg.tool_calls
|
||||
|
||||
req_messages.append(final_msg)
|
||||
@@ -1,284 +0,0 @@
|
||||
"""
|
||||
Legacy n8n Service API Runner.
|
||||
|
||||
DEPRECATED: This runner has been migrated to the AgentRunner plugin format.
|
||||
Use the official `langbot/n8n-agent` plugin instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import uuid
|
||||
import aiohttp
|
||||
|
||||
from langbot.pkg.utils import httpclient
|
||||
|
||||
from .. import runner
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class N8nAPIError(Exception):
|
||||
"""N8n API 请求失败"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
@runner.runner_class('n8n-service-api')
|
||||
class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
"""N8n Service API 工作流请求器"""
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
# 获取webhook URL
|
||||
self.webhook_url = self.pipeline_config['ai']['n8n-service-api']['webhook-url']
|
||||
|
||||
# 获取超时设置,默认为120秒
|
||||
self.timeout = self.pipeline_config['ai']['n8n-service-api'].get('timeout', 120)
|
||||
|
||||
# 获取输出键名,默认为response
|
||||
self.output_key = self.pipeline_config['ai']['n8n-service-api'].get('output-key', 'response')
|
||||
|
||||
# 获取认证类型,默认为none
|
||||
self.auth_type = self.pipeline_config['ai']['n8n-service-api'].get('auth-type', 'none')
|
||||
|
||||
# 根据认证类型获取相应的认证信息
|
||||
if self.auth_type == 'basic':
|
||||
self.basic_username = self.pipeline_config['ai']['n8n-service-api'].get('basic-username', '')
|
||||
self.basic_password = self.pipeline_config['ai']['n8n-service-api'].get('basic-password', '')
|
||||
elif self.auth_type == 'jwt':
|
||||
self.jwt_secret = self.pipeline_config['ai']['n8n-service-api'].get('jwt-secret', '')
|
||||
self.jwt_algorithm = self.pipeline_config['ai']['n8n-service-api'].get('jwt-algorithm', 'HS256')
|
||||
elif self.auth_type == 'header':
|
||||
self.header_name = self.pipeline_config['ai']['n8n-service-api'].get('header-name', '')
|
||||
self.header_value = self.pipeline_config['ai']['n8n-service-api'].get('header-value', '')
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> str:
|
||||
"""预处理用户消息,提取纯文本
|
||||
|
||||
Returns:
|
||||
str: 纯文本消息
|
||||
"""
|
||||
plain_text = ''
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
# 注意:n8n webhook目前不支持直接处理图片,如需支持可在此扩展
|
||||
elif isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
|
||||
return plain_text
|
||||
|
||||
async def _process_response(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""处理响应——支持流式格式和普通 JSON 格式"""
|
||||
full_content = ''
|
||||
full_text = ''
|
||||
chunk_idx = 0
|
||||
is_final = False
|
||||
message_idx = 0
|
||||
|
||||
buffer = ''
|
||||
decoder = json.JSONDecoder()
|
||||
|
||||
async for raw_chunk in response.content.iter_chunked(1024):
|
||||
if not raw_chunk:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 将 bytes 解码为字符串(容忍错误)
|
||||
if isinstance(raw_chunk, (bytes, bytearray)):
|
||||
chunk_str = raw_chunk.decode('utf-8', errors='replace')
|
||||
else:
|
||||
chunk_str = str(raw_chunk)
|
||||
|
||||
full_text += chunk_str
|
||||
buffer += chunk_str
|
||||
|
||||
# 尝试从 buffer 中循环解析出 JSON 对象(处理多个对象或部分对象)
|
||||
while buffer:
|
||||
buffer = buffer.lstrip()
|
||||
if not buffer:
|
||||
break
|
||||
try:
|
||||
obj, idx = decoder.raw_decode(buffer)
|
||||
buffer = buffer[idx:]
|
||||
|
||||
if not isinstance(obj, dict):
|
||||
# 忽略非字典类型的顶级 JSON
|
||||
continue
|
||||
|
||||
if obj.get('type') == 'item' and 'content' in obj:
|
||||
chunk_idx += 1
|
||||
content = obj['content']
|
||||
full_content += content
|
||||
elif obj.get('type') == 'end':
|
||||
is_final = True
|
||||
|
||||
if is_final or (chunk_idx > 0 and chunk_idx % 8 == 0):
|
||||
message_idx += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=full_content,
|
||||
is_final=is_final,
|
||||
msg_sequence=message_idx,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# buffer 末尾可能是一个不完整的 JSON,等待更多数据
|
||||
break
|
||||
except Exception as e:
|
||||
# 记录解析失败并继续接收后续 chunk
|
||||
try:
|
||||
preview = chunk_str[:200]
|
||||
except Exception:
|
||||
preview = '<unavailable>'
|
||||
self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}')
|
||||
|
||||
# 流结束后,尝试解析残余 buffer
|
||||
if buffer:
|
||||
try:
|
||||
buffer = buffer.strip()
|
||||
if buffer:
|
||||
obj, _ = decoder.raw_decode(buffer)
|
||||
if isinstance(obj, dict):
|
||||
if obj.get('type') == 'item' and 'content' in obj:
|
||||
chunk_idx += 1
|
||||
full_content += obj['content']
|
||||
elif obj.get('type') == 'end':
|
||||
is_final = True
|
||||
message_idx += 1
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=full_content,
|
||||
is_final=is_final,
|
||||
msg_sequence=message_idx,
|
||||
)
|
||||
except Exception as e:
|
||||
preview = buffer[:200]
|
||||
self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}')
|
||||
|
||||
# n8n 返回普通 JSON 格式(无任何流式 type:item 内容)
|
||||
if chunk_idx == 0:
|
||||
output_content = ''
|
||||
try:
|
||||
response_data = json.loads(full_text.strip())
|
||||
if isinstance(response_data, dict):
|
||||
if self.output_key in response_data:
|
||||
output_content = response_data[self.output_key]
|
||||
else:
|
||||
output_content = json.dumps(response_data, ensure_ascii=False)
|
||||
else:
|
||||
output_content = full_text
|
||||
except json.JSONDecodeError:
|
||||
output_content = full_text
|
||||
self.ap.logger.debug(f'n8n webhook response (non-stream): {full_text[:200]}')
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=output_content,
|
||||
is_final=True,
|
||||
msg_sequence=message_idx + 1,
|
||||
)
|
||||
|
||||
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用n8n webhook"""
|
||||
# 生成会话ID(如果不存在)
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
# Keep query variables in sync with the generated/new conversation id.
|
||||
# query.variables is later merged into payload and would otherwise
|
||||
# overwrite the generated conversation_id with the stale preprocessor
|
||||
# value (usually None for a new conversation).
|
||||
query.variables['conversation_id'] = query.session.using_conversation.uuid
|
||||
|
||||
# 预处理用户消息
|
||||
plain_text = await self._preprocess_user_message(query)
|
||||
|
||||
# 准备请求数据
|
||||
payload = {
|
||||
# 基本消息内容
|
||||
'chatInput': plain_text, # 考虑到之前用户直接用的message model这里添加新键
|
||||
'message': plain_text,
|
||||
'user_message_text': plain_text,
|
||||
'conversation_id': query.session.using_conversation.uuid,
|
||||
'session_id': query.variables.get('session_id', ''),
|
||||
'user_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
'msg_create_time': query.variables.get('msg_create_time', ''),
|
||||
}
|
||||
|
||||
# 添加所有变量到payload
|
||||
payload.update(query.variables)
|
||||
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
try:
|
||||
# 准备请求头和认证信息
|
||||
headers = {}
|
||||
auth = None
|
||||
|
||||
# 根据认证类型设置相应的认证信息
|
||||
if self.auth_type == 'basic':
|
||||
# 使用Basic认证
|
||||
auth = aiohttp.BasicAuth(self.basic_username, self.basic_password)
|
||||
self.ap.logger.debug(f'using basic auth: {self.basic_username}')
|
||||
elif self.auth_type == 'jwt':
|
||||
# 使用JWT认证
|
||||
import jwt
|
||||
import time
|
||||
|
||||
# 创建JWT令牌
|
||||
payload_jwt = {
|
||||
'exp': int(time.time()) + 3600, # 1小时过期
|
||||
'iat': int(time.time()),
|
||||
'sub': 'n8n-webhook',
|
||||
}
|
||||
token = jwt.encode(payload_jwt, self.jwt_secret, algorithm=self.jwt_algorithm)
|
||||
|
||||
# 添加到Authorization头
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
self.ap.logger.debug('using jwt auth')
|
||||
elif self.auth_type == 'header':
|
||||
# 使用自定义请求头认证
|
||||
headers[self.header_name] = self.header_value
|
||||
self.ap.logger.debug(f'using header auth: {self.header_name}')
|
||||
else:
|
||||
self.ap.logger.debug('no auth')
|
||||
|
||||
# 调用webhook
|
||||
session = httpclient.get_session()
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
|
||||
async for chunk in self._process_response(response):
|
||||
if is_stream:
|
||||
yield chunk
|
||||
elif chunk.is_final:
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=chunk.content,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
|
||||
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行请求"""
|
||||
async for msg in self._call_webhook(query):
|
||||
yield msg
|
||||
@@ -1,209 +0,0 @@
|
||||
"""
|
||||
Legacy Tbox (蚂蚁百宝箱) API Runner.
|
||||
|
||||
DEPRECATED: This runner has been migrated to the AgentRunner plugin format.
|
||||
Use the official `langbot/tbox-agent` plugin instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import base64
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from tboxsdk.tbox import TboxClient
|
||||
from tboxsdk.model.file import File, FileType
|
||||
|
||||
from .. import runner
|
||||
from ...core import app
|
||||
from ...utils import image
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
class TboxAPIError(Exception):
|
||||
"""TBox API 请求失败"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
@runner.runner_class('tbox-app-api')
|
||||
class TboxAPIRunner(runner.RequestRunner):
|
||||
"蚂蚁百宝箱API对话请求器"
|
||||
|
||||
# 运行器内部使用的配置
|
||||
app_id: str # 蚂蚁百宝箱平台中的应用ID
|
||||
api_key: str # 在蚂蚁百宝箱平台中申请的令牌
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
"""初始化"""
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
# 初始化Tbox 参数配置
|
||||
self.app_id = self.pipeline_config['ai']['tbox-app-api']['app-id']
|
||||
self.api_key = self.pipeline_config['ai']['tbox-app-api']['api-key']
|
||||
|
||||
# 初始化Tbox client
|
||||
self.tbox_client = TboxClient(authorization=self.api_key)
|
||||
|
||||
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片上传到 Tbox 服务
|
||||
|
||||
Returns:
|
||||
tuple[str, list[str]]: 纯文本和图片的 Tbox 文件ID
|
||||
"""
|
||||
plain_text = ''
|
||||
image_ids = []
|
||||
|
||||
if isinstance(query.user_message.content, list):
|
||||
for ce in query.user_message.content:
|
||||
if ce.type == 'text':
|
||||
plain_text += ce.text
|
||||
elif ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
# 创建临时文件
|
||||
file_bytes = base64.b64decode(image_b64)
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=f'.{image_format}', delete=False) as tmp_file:
|
||||
tmp_file.write(file_bytes)
|
||||
tmp_file_path = tmp_file.name
|
||||
file_upload_resp = self.tbox_client.upload_file(tmp_file_path)
|
||||
image_id = file_upload_resp.get('data', '')
|
||||
image_ids.append(image_id)
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(tmp_file_path):
|
||||
os.unlink(tmp_file_path)
|
||||
elif isinstance(query.user_message.content, str):
|
||||
plain_text = query.user_message.content
|
||||
|
||||
return plain_text, image_ids
|
||||
|
||||
async def _agent_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""TBox 智能体对话请求"""
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
|
||||
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
# 获取Tbox的conversation_id
|
||||
conversation_id = query.session.using_conversation.uuid or None
|
||||
|
||||
files = None
|
||||
if image_ids:
|
||||
files = [File(file_id=image_id, type=FileType.IMAGE) for image_id in image_ids]
|
||||
|
||||
# 发送对话请求
|
||||
response = self.tbox_client.chat(
|
||||
app_id=self.app_id, # Tbox中智能体应用的ID
|
||||
user_id=query.bot_uuid, # 用户ID
|
||||
query=plain_text, # 用户输入的文本信息
|
||||
stream=is_stream, # 是否流式输出
|
||||
conversation_id=conversation_id, # 会话ID,为None时Tbox会自动创建一个新会话
|
||||
files=files, # 图片内容
|
||||
)
|
||||
|
||||
if is_stream:
|
||||
# 解析Tbox流式输出内容,并发送给上游
|
||||
for chunk in self._process_stream_message(response, query, remove_think):
|
||||
yield chunk
|
||||
else:
|
||||
message = self._process_non_stream_message(response, query, remove_think)
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=message,
|
||||
)
|
||||
|
||||
def _process_non_stream_message(self, response: typing.Dict, query: pipeline_query.Query, remove_think: bool):
|
||||
if response.get('errorCode') != '0':
|
||||
raise TboxAPIError(f'Tbox API 请求失败: {response.get("errorMsg", "")}')
|
||||
payload = response.get('data', {})
|
||||
conversation_id = payload.get('conversationId', '')
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
thinking_content = payload.get('reasoningContent', [])
|
||||
result = ''
|
||||
if thinking_content and not remove_think:
|
||||
result += f'<think>\n{thinking_content[0].get("text", "")}\n</think>\n'
|
||||
content = payload.get('result', [])
|
||||
if content:
|
||||
result += content[0].get('chunk', '')
|
||||
return result
|
||||
|
||||
def _process_stream_message(
|
||||
self, response: typing.Generator[dict], query: pipeline_query.Query, remove_think: bool
|
||||
):
|
||||
idx_msg = 0
|
||||
pending_content = ''
|
||||
conversation_id = None
|
||||
think_start = False
|
||||
think_end = False
|
||||
for chunk in response:
|
||||
if chunk.get('type', '') == 'chunk':
|
||||
"""
|
||||
Tbox返回的消息内容chunk结构
|
||||
{'lane': 'default', 'payload': {'conversationId': '20250918tBI947065406', 'messageId': '20250918TB1f53230954', 'text': '️'}, 'type': 'chunk'}
|
||||
"""
|
||||
# 如果包含思考过程,拼接</think>
|
||||
if think_start and not think_end:
|
||||
pending_content += '\n</think>\n'
|
||||
think_end = True
|
||||
|
||||
payload = chunk.get('payload', {})
|
||||
if not conversation_id:
|
||||
conversation_id = payload.get('conversationId')
|
||||
query.session.using_conversation.uuid = conversation_id
|
||||
if payload.get('text'):
|
||||
idx_msg += 1
|
||||
pending_content += payload.get('text')
|
||||
elif chunk.get('type', '') == 'thinking' and not remove_think:
|
||||
"""
|
||||
Tbox返回的思考过程chunk结构
|
||||
{'payload': '{"ext_data":{"text":"日期"},"event":"flow.node.llm.thinking","entity":{"node_type":"text-completion","execute_id":"6","group_id":0,"parent_execute_id":"6","node_name":"模型推理","node_id":"TC_5u6gl0"}}', 'type': 'thinking'}
|
||||
"""
|
||||
payload = json.loads(chunk.get('payload', '{}'))
|
||||
if payload.get('ext_data', {}).get('text'):
|
||||
idx_msg += 1
|
||||
content = payload.get('ext_data', {}).get('text')
|
||||
if not think_start:
|
||||
think_start = True
|
||||
pending_content += f'<think>\n{content}'
|
||||
else:
|
||||
pending_content += content
|
||||
elif chunk.get('type', '') == 'error':
|
||||
raise TboxAPIError(
|
||||
f'Tbox API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
|
||||
if idx_msg % 8 == 0:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
# Tbox不返回END事件,默认发一个最终消息
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行"""
|
||||
msg_seq = 0
|
||||
async for msg in self._agent_messages(query):
|
||||
if isinstance(msg, provider_message.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
@@ -31,7 +31,7 @@ def mock_circular_import_chain():
|
||||
"""
|
||||
Break circular import chain for pipeline modules using isolated_sys_modules.
|
||||
|
||||
Chain: pipeline → core.app → provider.runner → http_controller → groups/plugins
|
||||
Chain: pipeline → core.app → http_controller → groups/plugins
|
||||
|
||||
We mock minimal modules to allow importing RuntimePipeline, StageInstContainer,
|
||||
and stage classes without triggering full application initialization.
|
||||
@@ -63,14 +63,12 @@ def mock_circular_import_chain():
|
||||
'langbot.pkg.pipeline.process.handlers.chat',
|
||||
'langbot.pkg.pipeline.process.handlers.command',
|
||||
'langbot.pkg.pipeline.respback.respback',
|
||||
'langbot.pkg.provider.runner',
|
||||
]
|
||||
|
||||
with isolated_sys_modules(
|
||||
mocks={
|
||||
'langbot.pkg.core.entities': mock_core_entities,
|
||||
'langbot.pkg.core.app': mock_core_app,
|
||||
'langbot.pkg.provider.runner': Mock(preregistered_runners=[]),
|
||||
'langbot.pkg.utils.importutil': mock_importutil,
|
||||
'langbot.pkg.pipeline.controller': Mock(),
|
||||
'langbot.pkg.pipeline.pipelinemgr': Mock(),
|
||||
@@ -342,7 +340,7 @@ class TestPreProcessorStage:
|
||||
|
||||
result = await preproc_stage.process(query, 'PreProcessor')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.result_type.name == entities.ResultType.CONTINUE.name
|
||||
assert result.new_query.session is not None
|
||||
assert result.new_query.user_message is not None
|
||||
|
||||
@@ -369,7 +367,7 @@ class TestPreProcessorStage:
|
||||
|
||||
result = await preproc_stage.process(query, 'PreProcessor')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.result_type.name == entities.ResultType.CONTINUE.name
|
||||
# Check user_message content
|
||||
assert result.new_query.user_message is not None
|
||||
assert result.new_query.user_message.role == 'user'
|
||||
@@ -440,7 +438,7 @@ class TestProcessorStage:
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert results[0].result_type.name == entities.ResultType.INTERRUPT.name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_prevent_default_with_reply_continues(self, pipeline_app, fake_platform_adapter):
|
||||
@@ -474,7 +472,7 @@ class TestProcessorStage:
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.CONTINUE
|
||||
assert results[0].result_type.name == entities.ResultType.CONTINUE.name
|
||||
assert len(query.resp_messages) == 1
|
||||
assert query.resp_messages[0] == reply_chain
|
||||
|
||||
@@ -518,7 +516,7 @@ class TestRunnerExceptionFlow:
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert results[0].result_type.name == entities.ResultType.INTERRUPT.name
|
||||
assert results[0].user_notice == 'Request failed.'
|
||||
assert results[0].error_notice is not None
|
||||
|
||||
@@ -556,7 +554,7 @@ class TestRunnerExceptionFlow:
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert results[0].result_type.name == entities.ResultType.INTERRUPT.name
|
||||
assert 'Custom runtime error' in results[0].user_notice
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -593,7 +591,7 @@ class TestRunnerExceptionFlow:
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert results[0].result_type.name == entities.ResultType.INTERRUPT.name
|
||||
assert results[0].user_notice is None
|
||||
|
||||
|
||||
@@ -625,7 +623,7 @@ class TestSendResponseBackStage:
|
||||
|
||||
result = await respback_stage.process(query, 'SendResponseBackStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.result_type.name == entities.ResultType.CONTINUE.name
|
||||
|
||||
# Check that adapter was called
|
||||
outbound = platform.get_outbound_messages()
|
||||
@@ -691,7 +689,7 @@ class TestStageChainIntegration:
|
||||
|
||||
# Run PreProcessor
|
||||
result1 = await preproc_stage.process(query, 'PreProcessor')
|
||||
assert result1.result_type == entities.ResultType.CONTINUE
|
||||
assert result1.result_type.name == entities.ResultType.CONTINUE.name
|
||||
query = result1.new_query
|
||||
|
||||
# Run Processor
|
||||
@@ -705,7 +703,7 @@ class TestStageChainIntegration:
|
||||
|
||||
# Run SendResponseBackStage
|
||||
result3 = await respback_stage.process(query, 'SendResponseBackStage')
|
||||
assert result3.result_type == entities.ResultType.CONTINUE
|
||||
assert result3.result_type.name == entities.ResultType.CONTINUE.name
|
||||
|
||||
# Verify adapter was called
|
||||
outbound = platform.get_outbound_messages()
|
||||
@@ -753,14 +751,14 @@ class TestStageChainIntegration:
|
||||
|
||||
# Run PreProcessor
|
||||
result1 = await preproc_stage.process(query, 'PreProcessor')
|
||||
assert result1.result_type == entities.ResultType.CONTINUE
|
||||
assert result1.result_type.name == entities.ResultType.CONTINUE.name
|
||||
query = result1.new_query
|
||||
|
||||
# Run Processor - should INTERRUPT
|
||||
results = await collect_processor_results(processor_stage, query, 'MessageProcessor')
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].result_type == entities.ResultType.INTERRUPT
|
||||
assert results[0].result_type.name == entities.ResultType.INTERRUPT.name
|
||||
|
||||
# Chain stops here - no resp_messages
|
||||
assert len(query.resp_messages) == 0
|
||||
|
||||
@@ -43,7 +43,7 @@ def make_session(
|
||||
plugin_identity: str = 'test/test-runner',
|
||||
resources: dict | None = None,
|
||||
conversation_id: str | None = None,
|
||||
permissions: dict[str, list[str]] | None = None,
|
||||
available_apis: dict[str, bool] | None = None,
|
||||
state_policy: dict[str, typing.Any] | None = None,
|
||||
state_context: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
@@ -62,7 +62,7 @@ def make_session(
|
||||
import time
|
||||
now = int(time.time())
|
||||
res = resources if resources is not None else make_resources()
|
||||
perms = permissions if permissions is not None else {}
|
||||
apis = available_apis if available_apis is not None else {}
|
||||
policy = (
|
||||
state_policy
|
||||
if state_policy is not None
|
||||
@@ -85,7 +85,7 @@ def make_session(
|
||||
'plugin_identity': plugin_identity,
|
||||
'authorization': {
|
||||
'resources': res,
|
||||
'permissions': perms,
|
||||
'available_apis': apis,
|
||||
'conversation_id': conversation_id,
|
||||
'state_policy': policy,
|
||||
'state_context': context,
|
||||
|
||||
@@ -212,7 +212,7 @@ class TestArtifactAccessValidation:
|
||||
return make_session(
|
||||
run_id="run_001",
|
||||
conversation_id=conversation_id,
|
||||
permissions={"artifacts": ["metadata", "read"]},
|
||||
available_apis={"artifact_metadata": True, "artifact_read": True},
|
||||
)
|
||||
|
||||
def _call_validate(self, session, metadata, operation="metadata"):
|
||||
@@ -298,33 +298,23 @@ class TestArtifactAccessValidation:
|
||||
|
||||
|
||||
class TestContextAccessArtifactAPIs:
|
||||
"""Test ContextAccess reflects artifact API permissions."""
|
||||
"""Test ContextAccess reflects runtime artifact API availability."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_access_has_artifact_apis_when_permitted(self):
|
||||
"""Test ContextAccess shows artifact APIs when permissions allow."""
|
||||
# This tests the context builder logic
|
||||
# When artifact permissions include 'metadata' and 'read',
|
||||
# available_apis should reflect that
|
||||
permissions = {"artifacts": ["metadata", "read"]}
|
||||
"""Artifact APIs are exposed through run-scoped available_apis."""
|
||||
available_apis = {"artifact_metadata": True, "artifact_read": True}
|
||||
|
||||
# Check that permissions are properly interpreted
|
||||
artifact_metadata_enabled = "metadata" in permissions.get("artifacts", [])
|
||||
artifact_read_enabled = "read" in permissions.get("artifacts", [])
|
||||
|
||||
assert artifact_metadata_enabled is True
|
||||
assert artifact_read_enabled is True
|
||||
assert available_apis["artifact_metadata"] is True
|
||||
assert available_apis["artifact_read"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_access_no_artifact_apis_without_permission(self):
|
||||
"""Test ContextAccess hides artifact APIs when permissions denied."""
|
||||
permissions = {"artifacts": []}
|
||||
"""Artifact APIs are absent when the run did not receive them."""
|
||||
available_apis = {}
|
||||
|
||||
artifact_metadata_enabled = "metadata" in permissions.get("artifacts", [])
|
||||
artifact_read_enabled = "read" in permissions.get("artifacts", [])
|
||||
|
||||
assert artifact_metadata_enabled is False
|
||||
assert artifact_read_enabled is False
|
||||
assert available_apis.get("artifact_metadata", False) is False
|
||||
assert available_apis.get("artifact_read", False) is False
|
||||
|
||||
|
||||
class TestArtifactMetadataFieldAlignment:
|
||||
@@ -376,8 +366,8 @@ class TestArtifactMetadataFieldAlignment:
|
||||
assert "storage_type" not in result
|
||||
|
||||
|
||||
class TestSessionRegistryPermissions:
|
||||
"""Test that session registry stores and retrieves permissions correctly."""
|
||||
class TestSessionRegistryAvailableAPIs:
|
||||
"""Test that session registry stores and retrieves available APIs correctly."""
|
||||
|
||||
@pytest.fixture
|
||||
def session_registry(self):
|
||||
@@ -387,8 +377,8 @@ class TestSessionRegistryPermissions:
|
||||
return get_session_registry()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_stores_permissions(self, session_registry):
|
||||
"""Test that register() stores permissions from descriptor."""
|
||||
async def test_register_stores_available_apis(self, session_registry):
|
||||
"""Test that register() stores runtime API availability."""
|
||||
await session_registry.register(
|
||||
run_id="run_001",
|
||||
runner_id="plugin:author/plugin/runner",
|
||||
@@ -402,24 +392,26 @@ class TestSessionRegistryPermissions:
|
||||
"storage": {"plugin_storage": True, "workspace_storage": False},
|
||||
"platform_capabilities": {},
|
||||
},
|
||||
permissions={
|
||||
"artifacts": ["metadata", "read"],
|
||||
"history": ["page"],
|
||||
"events": ["get"],
|
||||
available_apis={
|
||||
"artifact_metadata": True,
|
||||
"artifact_read": True,
|
||||
"history_page": True,
|
||||
"event_get": True,
|
||||
},
|
||||
conversation_id="conv_001",
|
||||
)
|
||||
|
||||
session = await session_registry.get("run_001")
|
||||
assert session is not None
|
||||
permissions = session["authorization"]["permissions"]
|
||||
assert permissions["artifacts"] == ["metadata", "read"]
|
||||
assert permissions["history"] == ["page"]
|
||||
assert permissions["events"] == ["get"]
|
||||
available_apis = session["authorization"]["available_apis"]
|
||||
assert available_apis["artifact_metadata"] is True
|
||||
assert available_apis["artifact_read"] is True
|
||||
assert available_apis["history_page"] is True
|
||||
assert available_apis["event_get"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_empty_permissions(self, session_registry):
|
||||
"""Test that register() handles empty permissions."""
|
||||
async def test_register_with_empty_available_apis(self, session_registry):
|
||||
"""Test that register() handles empty API availability."""
|
||||
await session_registry.register(
|
||||
run_id="run_002",
|
||||
runner_id="plugin:author/plugin/runner",
|
||||
@@ -433,13 +425,13 @@ class TestSessionRegistryPermissions:
|
||||
"storage": {"plugin_storage": True, "workspace_storage": False},
|
||||
"platform_capabilities": {},
|
||||
},
|
||||
permissions={},
|
||||
available_apis={},
|
||||
conversation_id="conv_001",
|
||||
)
|
||||
|
||||
session = await session_registry.get("run_002")
|
||||
assert session is not None
|
||||
assert session["authorization"]["permissions"] == {}
|
||||
assert session["authorization"]["available_apis"] == {}
|
||||
|
||||
|
||||
class TestArtifactStoreRealSQLite:
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.agent.runner.context_builder import AgentRunContextBuilder
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope, StatePolicy
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
|
||||
@@ -25,6 +26,27 @@ class MockApplication:
|
||||
self.persistence_mgr.get_db_engine = MagicMock()
|
||||
|
||||
|
||||
def make_descriptor(
|
||||
permissions: dict | None = None,
|
||||
) -> AgentRunnerDescriptor:
|
||||
return AgentRunnerDescriptor(
|
||||
id='plugin:test/runner/default',
|
||||
source='plugin',
|
||||
label={'en_US': 'Test Runner'},
|
||||
plugin_author='test',
|
||||
plugin_name='runner',
|
||||
runner_name='default',
|
||||
permissions=permissions
|
||||
if permissions is not None
|
||||
else {
|
||||
'history': ['page', 'search'],
|
||||
'events': ['get', 'page'],
|
||||
'artifacts': ['metadata', 'read'],
|
||||
'storage': ['plugin'],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestContextAccessStateDetermination:
|
||||
"""Tests for ContextAccess.state field determination - real calls to _build_context_access."""
|
||||
|
||||
@@ -54,10 +76,7 @@ class TestContextAccessStateDetermination:
|
||||
@pytest.fixture
|
||||
def mock_descriptor(self):
|
||||
"""Create mock runner descriptor."""
|
||||
descriptor = MagicMock()
|
||||
descriptor.id = 'plugin:test/runner/default'
|
||||
descriptor.permissions = {}
|
||||
return descriptor
|
||||
return make_descriptor()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_state_true_with_scopes_sets_state_true(self, mock_app, mock_event, mock_descriptor):
|
||||
@@ -237,7 +256,7 @@ class TestBindingWithStatePolicy:
|
||||
|
||||
|
||||
class TestContextAccessOtherAPIs:
|
||||
"""Tests for other available_apis fields based on permissions."""
|
||||
"""Tests for other available_apis fields based on run scope."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
@@ -245,16 +264,12 @@ class TestContextAccessOtherAPIs:
|
||||
return MockApplication()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_apis_based_on_permissions(self, mock_app):
|
||||
"""History APIs availability based on runner permissions."""
|
||||
async def test_history_apis_enabled_with_conversation(self, mock_app):
|
||||
"""History APIs are available when the run has a conversation scope."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.permissions = {
|
||||
'history': ['page', 'search'],
|
||||
}
|
||||
mock_descriptor = make_descriptor()
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
@@ -268,21 +283,16 @@ class TestContextAccessOtherAPIs:
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# History APIs enabled based on permissions
|
||||
assert context_access['available_apis']['history_page'] is True
|
||||
assert context_access['available_apis']['history_search'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_apis_based_on_permissions(self, mock_app):
|
||||
"""Event APIs availability based on runner permissions."""
|
||||
async def test_event_apis_enabled_by_default(self, mock_app):
|
||||
"""Event APIs are available based on current run scope."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.permissions = {
|
||||
'events': ['get', 'page'],
|
||||
}
|
||||
mock_descriptor = make_descriptor()
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
@@ -296,21 +306,16 @@ class TestContextAccessOtherAPIs:
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Event APIs enabled based on permissions
|
||||
assert context_access['available_apis']['event_get'] is True
|
||||
assert context_access['available_apis']['event_page'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_artifact_apis_based_on_permissions(self, mock_app):
|
||||
"""Artifact APIs availability based on runner permissions."""
|
||||
async def test_artifact_apis_enabled_by_default(self, mock_app):
|
||||
"""Artifact APIs are available based on current run scope."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.permissions = {
|
||||
'artifacts': ['metadata', 'read'],
|
||||
}
|
||||
mock_descriptor = make_descriptor()
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
@@ -324,19 +329,16 @@ class TestContextAccessOtherAPIs:
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# Artifact APIs enabled based on permissions
|
||||
assert context_access['available_apis']['artifact_metadata'] is True
|
||||
assert context_access['available_apis']['artifact_read'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_permissions_all_apis_disabled(self, mock_app):
|
||||
"""All pull APIs disabled when permissions are empty."""
|
||||
async def test_conversation_required_apis_disabled_without_conversation(self, mock_app):
|
||||
"""Conversation-scoped APIs are disabled when the run has no conversation."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.conversation_id = None
|
||||
mock_event.thread_id = None
|
||||
|
||||
mock_descriptor = MagicMock()
|
||||
mock_descriptor.permissions = {} # No permissions
|
||||
mock_descriptor = make_descriptor()
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
@@ -350,11 +352,37 @@ class TestContextAccessOtherAPIs:
|
||||
# Real call
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
# All pull APIs should be disabled
|
||||
assert context_access['available_apis']['history_page'] is False
|
||||
assert context_access['available_apis']['history_search'] is False
|
||||
assert context_access['available_apis']['event_get'] is True
|
||||
assert context_access['available_apis']['event_page'] is False
|
||||
assert context_access['available_apis']['artifact_metadata'] is True
|
||||
assert context_access['available_apis']['artifact_read'] is True
|
||||
assert context_access['available_apis']['state'] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manifest_permissions_disable_context_apis(self, mock_app):
|
||||
"""Pull APIs are disabled when manifest permissions omit them."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.conversation_id = 'conv_001'
|
||||
mock_event.thread_id = None
|
||||
mock_descriptor = make_descriptor(permissions={})
|
||||
|
||||
binding = AgentBinding(
|
||||
binding_id='binding_001',
|
||||
runner_id='plugin:test/runner/default',
|
||||
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
|
||||
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
|
||||
)
|
||||
|
||||
builder = AgentRunContextBuilder(mock_app)
|
||||
|
||||
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
|
||||
|
||||
assert context_access['available_apis']['history_page'] is False
|
||||
assert context_access['available_apis']['history_search'] is False
|
||||
assert context_access['available_apis']['event_get'] is False
|
||||
assert context_access['available_apis']['event_page'] is False
|
||||
assert context_access['available_apis']['artifact_metadata'] is False
|
||||
assert context_access['available_apis']['artifact_read'] is False
|
||||
assert context_access['available_apis']['state'] is False
|
||||
assert context_access['available_apis']['storage'] is False
|
||||
|
||||
@@ -18,6 +18,7 @@ from langbot.pkg.agent.runner.context_builder import (
|
||||
AgentRunContextBuilder,
|
||||
AgentResources as BuilderResources,
|
||||
)
|
||||
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
|
||||
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope
|
||||
from langbot.pkg.core import app
|
||||
|
||||
@@ -88,13 +89,20 @@ class TestContextValidation:
|
||||
|
||||
def _make_descriptor(self):
|
||||
"""Create a mock runner descriptor."""
|
||||
descriptor = MagicMock()
|
||||
descriptor.id = "plugin:test/plugin/runner"
|
||||
descriptor.permissions = {
|
||||
'history': ['page', 'search'],
|
||||
'events': ['get', 'page'],
|
||||
}
|
||||
return descriptor
|
||||
return AgentRunnerDescriptor(
|
||||
id="plugin:test/plugin/runner",
|
||||
source="plugin",
|
||||
label={"en_US": "Test Runner"},
|
||||
plugin_author="test",
|
||||
plugin_name="plugin",
|
||||
runner_name="runner",
|
||||
permissions={
|
||||
"history": ["page", "search"],
|
||||
"events": ["get", "page"],
|
||||
"artifacts": ["metadata", "read"],
|
||||
"storage": ["plugin", "workspace"],
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_context_from_event_validates(self):
|
||||
|
||||
@@ -23,12 +23,6 @@ from langbot_plugin.api.entities.builtin.agent_runner.result import (
|
||||
AgentRunResult,
|
||||
AgentRunResultType,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.capabilities import (
|
||||
AgentRunnerCapabilities,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.permissions import (
|
||||
AgentRunnerPermissions,
|
||||
)
|
||||
|
||||
# Import LangBot host models
|
||||
from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter
|
||||
@@ -57,6 +51,7 @@ class TestQueryToEventEnvelope:
|
||||
|
||||
assert event.input is not None
|
||||
assert event.input.text == "Hello world"
|
||||
assert "message_chain" not in event.input.model_dump()
|
||||
|
||||
def test_query_to_event_conversation(self, mock_query):
|
||||
"""Test conversation context extraction."""
|
||||
@@ -232,43 +227,6 @@ class TestHostManagedHistoryNotInProtocol:
|
||||
assert "messages" not in ctx_fields
|
||||
|
||||
|
||||
class TestSDKCapabilitiesProtocolV1:
|
||||
"""Test SDK capabilities for Protocol v1."""
|
||||
|
||||
def test_self_managed_context_default_true(self):
|
||||
"""Test self_managed_context defaults to True for Protocol v1."""
|
||||
caps = AgentRunnerCapabilities()
|
||||
|
||||
assert caps.self_managed_context is True
|
||||
|
||||
def test_event_context_default_true(self):
|
||||
"""Test event_context defaults to True for Protocol v1."""
|
||||
caps = AgentRunnerCapabilities()
|
||||
|
||||
assert caps.event_context is True
|
||||
|
||||
|
||||
class TestSDKPermissionsProtocolV1:
|
||||
"""Test SDK permissions for Protocol v1."""
|
||||
|
||||
def test_permissions_new_fields(self):
|
||||
"""Test new permission fields for Protocol v1."""
|
||||
perms = AgentRunnerPermissions(
|
||||
models=["invoke", "stream", "rerank"],
|
||||
tools=["detail", "call"],
|
||||
knowledge_bases=["list", "retrieve"],
|
||||
history=["page", "search"],
|
||||
events=["get", "page"],
|
||||
artifacts=["metadata", "read"],
|
||||
storage=["plugin", "workspace", "binding"],
|
||||
)
|
||||
|
||||
assert perms.history == ["page", "search"]
|
||||
assert perms.events == ["get", "page"]
|
||||
assert perms.artifacts == ["metadata", "read"]
|
||||
assert perms.storage == ["plugin", "workspace", "binding"]
|
||||
|
||||
|
||||
class TestSDKResultProtocolV1:
|
||||
"""Test SDK AgentRunResult for Protocol v1."""
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ async def _register_session(
|
||||
*,
|
||||
run_id='run_1',
|
||||
conversation_id='conv_1',
|
||||
permissions=None,
|
||||
available_apis=None,
|
||||
):
|
||||
await session_registry.register(
|
||||
run_id=run_id,
|
||||
@@ -73,13 +73,13 @@ async def _register_session(
|
||||
plugin_identity='test/runner',
|
||||
resources=make_resources(),
|
||||
conversation_id=conversation_id,
|
||||
permissions=permissions or {},
|
||||
available_apis=available_apis or {},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_page_requires_manifest_permission(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'history': []})
|
||||
async def test_history_page_requires_runtime_capability(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'history_page': False})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
|
||||
|
||||
@@ -94,7 +94,7 @@ async def test_history_page_requires_manifest_permission(session_registry, db_en
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_page_rejects_cross_conversation(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'history': ['page']})
|
||||
await _register_session(session_registry, available_apis={'history_page': True})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
|
||||
|
||||
@@ -110,7 +110,7 @@ async def test_history_page_rejects_cross_conversation(session_registry, db_engi
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_search_rejects_filter_conversation_override(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'history': ['search']})
|
||||
await _register_session(session_registry, available_apis={'history_search': True})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
history_search = handler.actions[PluginToRuntimeAction.HISTORY_SEARCH.value]
|
||||
|
||||
@@ -126,8 +126,8 @@ async def test_history_search_rejects_filter_conversation_override(session_regis
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_page_requires_manifest_permission(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'events': []})
|
||||
async def test_event_page_requires_runtime_capability(session_registry, db_engine):
|
||||
await _register_session(session_registry, available_apis={'event_page': False})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||
|
||||
@@ -142,7 +142,7 @@ async def test_event_page_requires_manifest_permission(session_registry, db_engi
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_page_rejects_cross_conversation(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'events': ['page']})
|
||||
await _register_session(session_registry, available_apis={'event_page': True})
|
||||
handler = _handler(db_engine, session_registry)
|
||||
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
|
||||
|
||||
@@ -158,7 +158,7 @@ async def test_event_page_rejects_cross_conversation(session_registry, db_engine
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_get_returns_sdk_record_projection(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'events': ['get']})
|
||||
await _register_session(session_registry, available_apis={'event_get': True})
|
||||
store = EventLogStore(db_engine)
|
||||
event_id = await store.append_event(
|
||||
event_id='evt_projection_1',
|
||||
@@ -193,7 +193,7 @@ async def test_event_get_returns_sdk_record_projection(session_registry, db_engi
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_page_returns_sdk_page_projection(session_registry, db_engine):
|
||||
await _register_session(session_registry, permissions={'events': ['page']})
|
||||
await _register_session(session_registry, available_apis={'event_page': True})
|
||||
store = EventLogStore(db_engine)
|
||||
await store.append_event(
|
||||
event_id='evt_projection_page_1',
|
||||
|
||||
@@ -159,17 +159,19 @@ def make_descriptor() -> AgentRunnerDescriptor:
|
||||
"knowledge_retrieval": True,
|
||||
"skill_authoring": True,
|
||||
},
|
||||
permissions={
|
||||
"models": ["invoke", "stream"],
|
||||
"tools": ["detail", "call"],
|
||||
"knowledge_bases": ["list", "retrieve"],
|
||||
"history": ["page", "search"],
|
||||
"events": ["get", "page"],
|
||||
"artifacts": ["metadata", "read"],
|
||||
"storage": ["plugin"],
|
||||
},
|
||||
config_schema=[
|
||||
{"name": "model", "type": "model-fallback-selector"},
|
||||
{"name": "knowledge-bases", "type": "knowledge-base-multi-selector", "default": []},
|
||||
],
|
||||
permissions={
|
||||
"models": ["invoke", "stream"],
|
||||
"tools": ["list", "detail", "call"],
|
||||
"knowledge_bases": ["list", "retrieve"],
|
||||
"storage": ["plugin"],
|
||||
"files": [],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,13 +13,23 @@ from langbot.pkg.agent.runner.resource_builder import AgentResourceBuilder
|
||||
|
||||
|
||||
RUNNER_ID = 'plugin:test/runner/default'
|
||||
FULL_PERMISSIONS = {
|
||||
'models': ['invoke', 'stream', 'rerank'],
|
||||
'tools': ['detail', 'call'],
|
||||
'knowledge_bases': ['list', 'retrieve'],
|
||||
'history': ['page', 'search'],
|
||||
'events': ['get', 'page'],
|
||||
'artifacts': ['metadata', 'read'],
|
||||
'storage': ['plugin', 'workspace'],
|
||||
'files': ['config', 'knowledge'],
|
||||
}
|
||||
|
||||
|
||||
def make_descriptor(
|
||||
*,
|
||||
permissions: dict | None = None,
|
||||
config_schema: list[dict] | None = None,
|
||||
capabilities: dict | None = None,
|
||||
permissions: dict | None = None,
|
||||
) -> AgentRunnerDescriptor:
|
||||
return AgentRunnerDescriptor(
|
||||
id=RUNNER_ID,
|
||||
@@ -29,7 +39,7 @@ def make_descriptor(
|
||||
plugin_name='runner',
|
||||
runner_name='default',
|
||||
capabilities=capabilities or {},
|
||||
permissions=permissions or {'models': ['invoke', 'stream']},
|
||||
permissions=permissions if permissions is not None else FULL_PERMISSIONS,
|
||||
config_schema=config_schema or [],
|
||||
)
|
||||
|
||||
@@ -113,7 +123,6 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=get_model_by_uuid)
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(side_effect=get_rerank_model_by_uuid)
|
||||
descriptor = make_descriptor(
|
||||
permissions={'models': ['invoke', 'stream', 'rerank']},
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'model-fallback-selector'},
|
||||
{'name': 'aux-model', 'type': 'llm-model-selector'},
|
||||
@@ -137,16 +146,16 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_still_honors_manifest_permissions(app):
|
||||
"""Config-selected models should not bypass runner manifest permissions."""
|
||||
async def test_build_models_from_config_without_manifest_acl(app):
|
||||
"""Config-selected models are not projected without manifest model permissions."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=make_model(model_type='rerank'))
|
||||
descriptor = make_descriptor(
|
||||
permissions={'models': []},
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'model-fallback-selector'},
|
||||
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
|
||||
],
|
||||
permissions={},
|
||||
)
|
||||
query = make_query({
|
||||
'model': {'primary': 'primary', 'fallbacks': ['fallback']},
|
||||
@@ -156,19 +165,16 @@ async def test_build_models_still_honors_manifest_permissions(app):
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == []
|
||||
app.model_mgr.get_model_by_uuid.assert_not_awaited()
|
||||
app.model_mgr.get_rerank_model_by_uuid.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_authorizes_rerank_only_runner(app):
|
||||
"""A rerank-only runner should receive config-selected rerank models."""
|
||||
async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
|
||||
"""Config-selected model references are projected regardless of method granularity."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
|
||||
return_value=make_model(model_type='rerank', provider='rerank-provider')
|
||||
)
|
||||
descriptor = make_descriptor(
|
||||
permissions={'models': ['rerank']},
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'llm-model-selector'},
|
||||
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
|
||||
@@ -181,10 +187,39 @@ async def test_build_models_authorizes_rerank_only_runner(app):
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'llm', 'model_type': 'llm', 'provider': 'test-provider'},
|
||||
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider'},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_models_manifest_permission_narrows_binding(app):
|
||||
"""Manifest model permissions narrower than binding should remove LLM grants."""
|
||||
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
|
||||
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
|
||||
return_value=make_model(model_type='rerank', provider='rerank-provider')
|
||||
)
|
||||
descriptor = make_descriptor(
|
||||
config_schema=[
|
||||
{'name': 'model', 'type': 'llm-model-selector'},
|
||||
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
|
||||
],
|
||||
permissions={
|
||||
**FULL_PERMISSIONS,
|
||||
'models': ['rerank'],
|
||||
},
|
||||
)
|
||||
query = make_query({
|
||||
'model': 'llm',
|
||||
'rerank-model': 'rerank',
|
||||
})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['models'] == [
|
||||
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider'},
|
||||
]
|
||||
app.model_mgr.get_model_by_uuid.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -212,10 +247,7 @@ async def test_build_models_deduplicates_query_and_config_models(app):
|
||||
async def test_build_tools_authorizes_query_declared_tools(app):
|
||||
"""Tools discovered by Pipeline preprocessing become run-scoped authorized resources."""
|
||||
descriptor = make_descriptor(
|
||||
permissions={
|
||||
'models': [],
|
||||
'tools': ['detail', 'call'],
|
||||
},
|
||||
capabilities={'tool_calling': True},
|
||||
)
|
||||
query = make_query(
|
||||
{},
|
||||
@@ -241,14 +273,32 @@ async def test_build_tools_authorizes_query_declared_tools(app):
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_tools_manifest_permission_denies_binding_tools(app):
|
||||
"""Binding tool grants should be removed when manifest does not request tools."""
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'tool_calling': True},
|
||||
permissions={
|
||||
**FULL_PERMISSIONS,
|
||||
'tools': [],
|
||||
},
|
||||
)
|
||||
query = make_query(
|
||||
{},
|
||||
use_funcs=[
|
||||
{'name': 'qa_plugin_echo', 'description': 'Echo test tool'},
|
||||
],
|
||||
)
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['tools'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_knowledge_bases_unions_config_and_policy_grants(app):
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'knowledge_retrieval': True},
|
||||
permissions={
|
||||
'models': [],
|
||||
'knowledge_bases': ['retrieve'],
|
||||
},
|
||||
config_schema=[
|
||||
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
|
||||
],
|
||||
@@ -273,3 +323,43 @@ async def test_build_knowledge_bases_unions_config_and_policy_grants(app):
|
||||
{'kb_id': 'kb_config', 'kb_name': 'name-kb_config', 'kb_type': 'default'},
|
||||
{'kb_id': 'kb_policy', 'kb_name': 'name-kb_policy', 'kb_type': 'default'},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_knowledge_bases_manifest_permission_denies_binding_kbs(app):
|
||||
descriptor = make_descriptor(
|
||||
capabilities={'knowledge_retrieval': True},
|
||||
permissions={
|
||||
**FULL_PERMISSIONS,
|
||||
'knowledge_bases': [],
|
||||
},
|
||||
config_schema=[
|
||||
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
|
||||
],
|
||||
)
|
||||
query = make_query(
|
||||
{'knowledge-bases': ['kb_config']},
|
||||
variables={'_knowledge_base_uuids': ['kb_policy']},
|
||||
)
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['knowledge_bases'] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_storage_intersects_manifest_and_binding_policy(app):
|
||||
descriptor = make_descriptor(
|
||||
permissions={
|
||||
**FULL_PERMISSIONS,
|
||||
'storage': ['plugin'],
|
||||
},
|
||||
)
|
||||
query = make_query({})
|
||||
|
||||
resources = await build_resources(app, query, descriptor)
|
||||
|
||||
assert resources['storage'] == {
|
||||
'plugin_storage': True,
|
||||
'workspace_storage': False,
|
||||
}
|
||||
|
||||
@@ -14,12 +14,15 @@ class FakeApplication:
|
||||
"""Fake Application for testing."""
|
||||
def __init__(self):
|
||||
class FakeLogger:
|
||||
def __init__(self):
|
||||
self.warnings = []
|
||||
|
||||
def info(self, msg):
|
||||
pass
|
||||
def debug(self, msg):
|
||||
pass
|
||||
def warning(self, msg):
|
||||
pass
|
||||
self.warnings.append(msg)
|
||||
def error(self, msg):
|
||||
pass
|
||||
|
||||
@@ -67,7 +70,7 @@ class TestNormalizeMessageDelta:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_message_delta_missing_chunk(self):
|
||||
"""Normalize message.delta without chunk data."""
|
||||
"""Invalid message.delta payload is dropped."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
@@ -76,10 +79,9 @@ class TestNormalizeMessageDelta:
|
||||
'data': {},
|
||||
}
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await normalizer.normalize(result_dict, descriptor)
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert 'missing chunk data' in str(exc_info.value)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestNormalizeMessageCompleted:
|
||||
@@ -110,7 +112,7 @@ class TestNormalizeMessageCompleted:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalize_message_completed_missing_message(self):
|
||||
"""Normalize message.completed without message data."""
|
||||
"""Invalid message.completed payload is dropped."""
|
||||
normalizer = AgentResultNormalizer(FakeApplication())
|
||||
descriptor = make_descriptor()
|
||||
|
||||
@@ -119,10 +121,9 @@ class TestNormalizeMessageCompleted:
|
||||
'data': {},
|
||||
}
|
||||
|
||||
with pytest.raises(RunnerProtocolError) as exc_info:
|
||||
await normalizer.normalize(result_dict, descriptor)
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
|
||||
assert 'missing message data' in str(exc_info.value)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestNormalizeRunCompleted:
|
||||
@@ -260,13 +261,57 @@ class TestNormalizeNonMessageResults:
|
||||
'type': 'action.requested',
|
||||
'data': {
|
||||
'action': 'platform.message.edit',
|
||||
'parameters': {},
|
||||
'payload': {},
|
||||
},
|
||||
}
|
||||
|
||||
result = await normalizer.normalize(result_dict, descriptor)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_state_updated_payload_is_dropped(self):
|
||||
"""Invalid state.updated payload returns None with a warning."""
|
||||
app = FakeApplication()
|
||||
normalizer = AgentResultNormalizer(app)
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result = await normalizer.normalize(
|
||||
{
|
||||
'type': 'state.updated',
|
||||
'data': {
|
||||
'scope': 'invalid',
|
||||
'key': 'k',
|
||||
'value': 'v',
|
||||
},
|
||||
},
|
||||
descriptor,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert app.logger.warnings
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_artifact_created_payload_is_dropped(self):
|
||||
"""Invalid artifact.created payload returns None with a warning."""
|
||||
app = FakeApplication()
|
||||
normalizer = AgentResultNormalizer(app)
|
||||
descriptor = make_descriptor()
|
||||
|
||||
result = await normalizer.normalize(
|
||||
{
|
||||
'type': 'artifact.created',
|
||||
'data': {
|
||||
'artifact_id': 'artifact-1',
|
||||
'artifact_type': 'file',
|
||||
'content_base64': 'not base64',
|
||||
},
|
||||
},
|
||||
descriptor,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert app.logger.warnings
|
||||
|
||||
|
||||
class TestNormalizeInvalidResults:
|
||||
"""Tests for handling invalid results."""
|
||||
|
||||
@@ -63,7 +63,7 @@ class TestSessionRegistryBasic:
|
||||
query_id=1,
|
||||
plugin_identity='test/my-runner',
|
||||
resources=resources,
|
||||
permissions={'models': ['invoke']},
|
||||
available_apis={'history_page': True},
|
||||
conversation_id='conv_001',
|
||||
)
|
||||
|
||||
@@ -74,7 +74,7 @@ class TestSessionRegistryBasic:
|
||||
assert session is not None
|
||||
authorization = session['authorization']
|
||||
assert authorization['conversation_id'] == 'conv_001'
|
||||
assert authorization['permissions'] == {'models': ['invoke']}
|
||||
assert authorization['available_apis'] == {'history_page': True}
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_001') is True
|
||||
assert registry.is_resource_allowed(session, 'model', 'model_late') is False
|
||||
assert registry.is_resource_allowed(session, 'storage', 'workspace') is False
|
||||
|
||||
@@ -14,6 +14,23 @@ from tests.factories import FakeApp
|
||||
|
||||
|
||||
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
|
||||
_current_runner_class = None
|
||||
|
||||
|
||||
def _default_runner_class():
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
class DefaultRunner:
|
||||
name = 'local-agent'
|
||||
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
|
||||
async def run(self, query):
|
||||
yield Message(role='assistant', content='fake response')
|
||||
|
||||
return DefaultRunner
|
||||
|
||||
|
||||
def runner_pipeline_config(output_misc: dict) -> dict:
|
||||
@@ -47,21 +64,8 @@ def mock_circular_import_chain():
|
||||
make_pipeline_handler_import_mocks,
|
||||
get_handler_modules_to_clear,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.provider.message import Message
|
||||
|
||||
mocks = make_pipeline_handler_import_mocks()
|
||||
|
||||
# Create a default runner that yields a simple response
|
||||
class DefaultRunner:
|
||||
name = 'local-agent'
|
||||
def __init__(self, app, config):
|
||||
self.app = app
|
||||
self.config = config
|
||||
async def run(self, query):
|
||||
yield Message(role='assistant', content='fake response')
|
||||
|
||||
mocks['langbot.pkg.provider.runner'].preregistered_runners = [DefaultRunner]
|
||||
|
||||
clear = get_handler_modules_to_clear('chat')
|
||||
|
||||
with isolated_sys_modules(mocks=mocks, clear=clear):
|
||||
@@ -75,9 +79,7 @@ def fake_app():
|
||||
|
||||
class ProviderRunnerBackedOrchestrator:
|
||||
async def run_from_query(self, query):
|
||||
import sys
|
||||
|
||||
runner_class = sys.modules['langbot.pkg.provider.runner'].preregistered_runners[0]
|
||||
runner_class = _current_runner_class or _default_runner_class()
|
||||
runner = runner_class(app, {})
|
||||
async for result in runner.run(query):
|
||||
yield result
|
||||
@@ -103,10 +105,15 @@ def mock_event_ctx():
|
||||
@pytest.fixture
|
||||
def set_runner():
|
||||
"""Factory fixture to set a custom runner for tests."""
|
||||
global _current_runner_class
|
||||
previous = _current_runner_class
|
||||
|
||||
def _set_runner(runner_class):
|
||||
import sys
|
||||
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
|
||||
return _set_runner
|
||||
global _current_runner_class
|
||||
_current_runner_class = runner_class
|
||||
|
||||
yield _set_runner
|
||||
_current_runner_class = previous
|
||||
|
||||
|
||||
# ============== CACHED LAZY IMPORTS ==============
|
||||
|
||||
@@ -1,353 +0,0 @@
|
||||
"""
|
||||
Unit tests for N8nServiceAPIRunner._process_response
|
||||
|
||||
Tests cover four scenarios:
|
||||
- Stream adapter + n8n stream format (type:item/end)
|
||||
- Stream adapter + n8n plain JSON
|
||||
- Non-stream adapter + n8n stream format
|
||||
- Non-stream adapter + n8n plain JSON
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
# Break the circular import chain while importing n8nsvapi:
|
||||
# n8nsvapi → runner → app → pipelinemgr → all runners → runner (partially init)
|
||||
# The stubs are restored in a ``finally`` block so this module does NOT pollute
|
||||
# sys.modules for other test modules (e.g. ones importing the real
|
||||
# LocalAgentRunner, which would otherwise inherit ``object`` and break).
|
||||
# Mirrors master's intent but uses try/finally so a raised import doesn't
|
||||
# leave the global namespace in a stubbed state, and includes
|
||||
# ``langbot.pkg.utils.httpclient`` which master didn't stub.
|
||||
_runner_stub = MagicMock()
|
||||
_runner_stub.runner_class = lambda name: (lambda cls: cls) # no-op decorator
|
||||
_runner_stub.RequestRunner = object
|
||||
_import_stubs = {
|
||||
'langbot.pkg.provider.runner': _runner_stub,
|
||||
'langbot.pkg.core.app': MagicMock(),
|
||||
'langbot.pkg.utils.httpclient': MagicMock(),
|
||||
}
|
||||
_saved_modules = {name: sys.modules.get(name) for name in _import_stubs}
|
||||
for _name, _stub in _import_stubs.items():
|
||||
sys.modules[_name] = _stub
|
||||
try:
|
||||
from langbot.pkg.provider.runners.n8nsvapi import N8nServiceAPIRunner
|
||||
finally:
|
||||
for _name, _original in _saved_modules.items():
|
||||
if _original is None:
|
||||
sys.modules.pop(_name, None)
|
||||
else:
|
||||
sys.modules[_name] = _original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_runner(output_key: str = 'response') -> N8nServiceAPIRunner:
|
||||
ap = Mock()
|
||||
ap.logger = Mock()
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'n8n-service-api': {
|
||||
'webhook-url': 'http://test-n8n/webhook',
|
||||
'output-key': output_key,
|
||||
'auth-type': 'none',
|
||||
}
|
||||
}
|
||||
}
|
||||
return N8nServiceAPIRunner(ap, pipeline_config)
|
||||
|
||||
|
||||
def make_mock_response(chunks: list[bytes | str], status: int = 200):
|
||||
"""Build a minimal aiohttp.ClientResponse mock with iter_chunked support."""
|
||||
response = Mock()
|
||||
response.status = status
|
||||
|
||||
async def iter_chunked(size):
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
response.content = Mock()
|
||||
response.content.iter_chunked = iter_chunked
|
||||
return response
|
||||
|
||||
|
||||
async def collect_chunks(runner: N8nServiceAPIRunner, chunks: list[bytes | str]):
|
||||
"""Run _process_response and collect all yielded MessageChunks."""
|
||||
response = make_mock_response(chunks)
|
||||
result = []
|
||||
async for chunk in runner._process_response(response):
|
||||
result.append(chunk)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _process_response: stream format (type:item/end)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_format_single_item():
|
||||
"""Single item + end in one chunk yields final chunk with full content."""
|
||||
runner = make_runner()
|
||||
data = b'{"type":"item","content":"hello"}{"type":"end"}'
|
||||
|
||||
chunks = await collect_chunks(runner, [data])
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert chunks[0].content == 'hello'
|
||||
assert chunks[0].msg_sequence == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_format_multi_item_accumulates():
|
||||
"""Multiple items accumulate into full_content."""
|
||||
runner = make_runner()
|
||||
chunks_data = [
|
||||
b'{"type":"item","content":"foo"}',
|
||||
b'{"type":"item","content":"bar"}',
|
||||
b'{"type":"end"}',
|
||||
]
|
||||
|
||||
chunks = await collect_chunks(runner, chunks_data)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert chunks[0].content == 'foobar'
|
||||
assert chunks[0].msg_sequence == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_format_batches_every_8_items():
|
||||
"""Every 8th item triggers an intermediate yield before the final."""
|
||||
runner = make_runner()
|
||||
items = [f'{{"type":"item","content":"{i}"}}' for i in range(8)]
|
||||
items.append('{"type":"end"}')
|
||||
data = ''.join(items).encode()
|
||||
|
||||
chunks = await collect_chunks(runner, [data])
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].is_final is False
|
||||
assert chunks[0].content == '01234567'
|
||||
assert chunks[0].msg_sequence == 1
|
||||
assert chunks[1].is_final is True
|
||||
assert chunks[1].content == '01234567'
|
||||
assert chunks[1].msg_sequence == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_format_split_across_network_chunks():
|
||||
"""JSON split across multiple network chunks is reassembled correctly."""
|
||||
runner = make_runner()
|
||||
part1 = b'{"type":"item","con'
|
||||
part2 = b'tent":"world"}{"type":"end"}'
|
||||
|
||||
chunks = await collect_chunks(runner, [part1, part2])
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert chunks[0].content == 'world'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_format_no_spurious_empty_yield():
|
||||
"""chunk_idx==0 guard prevents spurious empty yield before any item is received."""
|
||||
runner = make_runner()
|
||||
# Send some non-stream JSON first, then stream
|
||||
data = b'{"type":"item","content":"x"}{"type":"end"}'
|
||||
|
||||
chunks = await collect_chunks(runner, [data])
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].content == 'x'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _process_response: plain JSON fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_json_with_output_key():
|
||||
"""Plain JSON with matching output_key extracts value via output_key."""
|
||||
runner = make_runner(output_key='response')
|
||||
data = json.dumps({'response': 'hello world'}).encode()
|
||||
|
||||
chunks = await collect_chunks(runner, [data])
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert chunks[0].content == 'hello world'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_json_output_key_not_found():
|
||||
"""Plain JSON without output_key falls back to entire JSON string."""
|
||||
runner = make_runner(output_key='response')
|
||||
payload = {'other_key': 'hello'}
|
||||
data = json.dumps(payload).encode()
|
||||
|
||||
chunks = await collect_chunks(runner, [data])
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert json.loads(chunks[0].content) == payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_json_output_key_empty_string():
|
||||
"""output_key present but value is empty string — returns empty string, not whole JSON."""
|
||||
runner = make_runner(output_key='response')
|
||||
data = json.dumps({'response': ''}).encode()
|
||||
|
||||
chunks = await collect_chunks(runner, [data])
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert chunks[0].content == ''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_json_non_dict_response():
|
||||
"""Plain JSON array falls back to raw text."""
|
||||
runner = make_runner()
|
||||
data = b'["a", "b"]'
|
||||
|
||||
chunks = await collect_chunks(runner, [data])
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert chunks[0].content == '["a", "b"]'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_returns_raw_text():
|
||||
"""Non-JSON response returns raw text as-is."""
|
||||
runner = make_runner()
|
||||
data = b'plain text response'
|
||||
|
||||
chunks = await collect_chunks(runner, [data])
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].is_final is True
|
||||
assert chunks[0].content == 'plain text response'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _call_webhook: output type depends on is_stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_query(is_stream: bool):
|
||||
"""Build a minimal Query mock."""
|
||||
query = Mock()
|
||||
query.adapter = AsyncMock()
|
||||
query.adapter.is_stream_output_supported = AsyncMock(return_value=is_stream)
|
||||
|
||||
session = Mock()
|
||||
session.using_conversation = Mock()
|
||||
session.using_conversation.uuid = 'test-uuid'
|
||||
session.launcher_type = Mock()
|
||||
session.launcher_type.value = 'person'
|
||||
session.launcher_id = '12345'
|
||||
query.session = session
|
||||
|
||||
query.user_message = Mock()
|
||||
query.user_message.content = 'hi'
|
||||
query.variables = {}
|
||||
return query
|
||||
|
||||
|
||||
def make_http_session_mock(response_bytes: bytes, status: int = 200):
|
||||
"""Mock httpclient.get_session() returning a session whose post() yields response_bytes."""
|
||||
mock_response = make_mock_response([response_bytes], status=status)
|
||||
mock_response.status = status
|
||||
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_cm)
|
||||
return mock_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_webhook_nonstream_adapter_plain_json():
|
||||
"""Non-stream adapter + plain JSON → single Message with output_key value."""
|
||||
runner = make_runner(output_key='response')
|
||||
query = make_query(is_stream=False)
|
||||
http_session = make_http_session_mock(json.dumps({'response': 'result text'}).encode())
|
||||
|
||||
with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
|
||||
results = []
|
||||
async for msg in runner._call_webhook(query):
|
||||
results.append(msg)
|
||||
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], provider_message.Message)
|
||||
assert results[0].content == 'result text'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_webhook_stream_adapter_stream_format():
|
||||
"""Stream adapter + stream format → MessageChunks, last is_final."""
|
||||
runner = make_runner()
|
||||
query = make_query(is_stream=True)
|
||||
data = b'{"type":"item","content":"hi"}{"type":"end"}'
|
||||
http_session = make_http_session_mock(data)
|
||||
|
||||
with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
|
||||
results = []
|
||||
async for msg in runner._call_webhook(query):
|
||||
results.append(msg)
|
||||
|
||||
assert all(isinstance(r, provider_message.MessageChunk) for r in results)
|
||||
assert results[-1].is_final is True
|
||||
assert results[-1].content == 'hi'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_webhook_stream_adapter_plain_json():
|
||||
"""Stream adapter + plain JSON → single MessageChunk with is_final=True."""
|
||||
runner = make_runner(output_key='response')
|
||||
query = make_query(is_stream=True)
|
||||
data = json.dumps({'response': 'fallback'}).encode()
|
||||
http_session = make_http_session_mock(data)
|
||||
|
||||
with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
|
||||
results = []
|
||||
async for msg in runner._call_webhook(query):
|
||||
results.append(msg)
|
||||
|
||||
assert all(isinstance(r, provider_message.MessageChunk) for r in results)
|
||||
assert results[-1].is_final is True
|
||||
assert results[-1].content == 'fallback'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_webhook_nonstream_adapter_stream_format():
|
||||
"""Non-stream adapter + stream format → single Message with accumulated content."""
|
||||
runner = make_runner()
|
||||
query = make_query(is_stream=False)
|
||||
data = b'{"type":"item","content":"foo"}{"type":"item","content":"bar"}{"type":"end"}'
|
||||
http_session = make_http_session_mock(data)
|
||||
|
||||
with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
|
||||
results = []
|
||||
async for msg in runner._call_webhook(query):
|
||||
results.append(msg)
|
||||
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], provider_message.Message)
|
||||
assert results[0].content == 'foobar'
|
||||
@@ -73,8 +73,8 @@ def make_host_model_runner_descriptor(
|
||||
'skill_authoring': skill_authoring,
|
||||
},
|
||||
permissions={
|
||||
'models': ['list', 'invoke', 'stream'],
|
||||
'tools': ['list', 'detail', 'call'],
|
||||
'models': ['invoke', 'stream'],
|
||||
'tools': ['detail', 'call'],
|
||||
'knowledge_bases': ['list', 'retrieve'],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
"""Tests for DifyServiceAPIRunner pure utility methods.
|
||||
|
||||
Tests the helper methods that don't require real Dify API calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDifyExtractTextOutput:
|
||||
"""Tests for _extract_dify_text_output method."""
|
||||
|
||||
def _create_runner(self):
|
||||
"""Create runner instance."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
|
||||
|
||||
mock_app = MagicMock()
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'dify-service-api': {
|
||||
'app-type': 'chat',
|
||||
'api-key': 'test-key',
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}}
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
runner.dify_client = MagicMock()
|
||||
|
||||
return runner
|
||||
|
||||
def test_extract_none_value(self):
|
||||
"""None returns empty string."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output(None)
|
||||
|
||||
assert result == ''
|
||||
|
||||
def test_extract_string_value(self):
|
||||
"""Plain string is returned."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output('plain text')
|
||||
|
||||
assert result == 'plain text'
|
||||
|
||||
def test_extract_dict_with_content(self):
|
||||
"""Dict with 'content' key extracts content."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output({'content': 'extracted content'})
|
||||
|
||||
assert result == 'extracted content'
|
||||
|
||||
def test_extract_dict_without_content(self):
|
||||
"""Dict without 'content' key is JSON dumped."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output({'key': 'value'})
|
||||
|
||||
assert 'key' in result
|
||||
assert 'value' in result
|
||||
|
||||
def test_extract_json_string_with_content(self):
|
||||
"""JSON string with 'content' key extracts content."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output('{"content": "json content"}')
|
||||
|
||||
assert result == 'json content'
|
||||
|
||||
def test_extract_json_string_without_content(self):
|
||||
"""JSON string without 'content' key returns original."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output('{"other": "value"}')
|
||||
|
||||
assert '{"other": "value"}' in result
|
||||
|
||||
def test_extract_whitespace_string(self):
|
||||
"""Whitespace string returns empty."""
|
||||
runner = self._create_runner()
|
||||
|
||||
result = runner._extract_dify_text_output(' ')
|
||||
|
||||
assert result == ''
|
||||
|
||||
|
||||
class TestDifyRunnerConfigValidation:
|
||||
"""Tests for runner config validation."""
|
||||
|
||||
def test_invalid_app_type_raises(self):
|
||||
"""Invalid app-type raises DifyAPIError."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
|
||||
from langbot.libs.dify_service_api.v1.errors import DifyAPIError
|
||||
|
||||
mock_app = MagicMock()
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'dify-service-api': {
|
||||
'app-type': 'invalid-type',
|
||||
'api-key': 'test',
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}}
|
||||
}
|
||||
|
||||
with pytest.raises(DifyAPIError, match='不支持'):
|
||||
DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
|
||||
def test_valid_app_types(self):
|
||||
"""Valid app-types don't raise."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
for app_type in ['chat', 'agent', 'workflow']:
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'dify-service-api': {
|
||||
'app-type': app_type,
|
||||
'api-key': 'test',
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}}
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
# Should not raise
|
||||
assert runner is not None
|
||||
|
||||
|
||||
class TestDifyRunnerInit:
|
||||
"""Tests for runner initialization."""
|
||||
|
||||
def test_runner_stores_config(self):
|
||||
"""Runner stores pipeline_config."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
|
||||
|
||||
mock_app = MagicMock()
|
||||
pipeline_config = {
|
||||
'ai': {
|
||||
'dify-service-api': {
|
||||
'app-type': 'chat',
|
||||
'api-key': 'test-key',
|
||||
'base-url': 'https://api.dify.ai',
|
||||
}
|
||||
},
|
||||
'output': {'misc': {}}
|
||||
}
|
||||
|
||||
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
|
||||
|
||||
assert runner.pipeline_config == pipeline_config
|
||||
assert runner.ap == mock_app
|
||||
@@ -1,242 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||
|
||||
|
||||
class RecordingProvider:
|
||||
def __init__(self):
|
||||
self.requests: list[dict] = []
|
||||
|
||||
async def invoke_llm(self, query, model, messages, funcs, extra_args=None, remove_think=None):
|
||||
self.requests.append(
|
||||
{
|
||||
'messages': list(messages),
|
||||
'funcs': list(funcs),
|
||||
'remove_think': remove_think,
|
||||
}
|
||||
)
|
||||
|
||||
if len(self.requests) == 1:
|
||||
return provider_message.Message(
|
||||
role='assistant',
|
||||
content='Let me calculate that exactly.',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name='exec',
|
||||
arguments=json.dumps(
|
||||
{'command': ("python - <<'PY'\nnums = [1, 2, 3, 4]\nprint(sum(nums) / len(nums))\nPY")}
|
||||
),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
tool_result = json.loads(messages[-1].content)
|
||||
return provider_message.Message(
|
||||
role='assistant',
|
||||
content=f'The average is {tool_result["stdout"]}.',
|
||||
)
|
||||
|
||||
|
||||
class RecordingStreamProvider:
|
||||
def __init__(self):
|
||||
self.stream_requests: list[dict] = []
|
||||
|
||||
def invoke_llm_stream(self, query, model, messages, funcs, extra_args=None, remove_think=None):
|
||||
self.stream_requests.append(
|
||||
{
|
||||
'messages': list(messages),
|
||||
'funcs': list(funcs),
|
||||
'remove_think': remove_think,
|
||||
}
|
||||
)
|
||||
|
||||
async def _stream():
|
||||
if len(self.stream_requests) == 1:
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(
|
||||
name='exec',
|
||||
arguments=json.dumps({'command': "python -c 'print(1)'"}),
|
||||
),
|
||||
)
|
||||
],
|
||||
is_final=True,
|
||||
)
|
||||
return
|
||||
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content='Tool execution failed.',
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
def make_query() -> pipeline_query.Query:
|
||||
adapter = AsyncMock()
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=False)
|
||||
|
||||
return pipeline_query.Query.model_construct(
|
||||
query_id='avg-query',
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_chain=[],
|
||||
message_event=None,
|
||||
adapter=adapter,
|
||||
pipeline_uuid='pipeline-uuid',
|
||||
bot_uuid='bot-uuid',
|
||||
pipeline_config={
|
||||
'ai': {
|
||||
'runner': {'runner': 'local-agent'},
|
||||
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
|
||||
},
|
||||
'output': {'misc': {'remove-think': False}},
|
||||
},
|
||||
prompt=SimpleNamespace(messages=[]),
|
||||
messages=[],
|
||||
user_message=provider_message.Message(
|
||||
role='user',
|
||||
content='Please calculate the average of 1, 2, 3, and 4.',
|
||||
),
|
||||
use_funcs=[SimpleNamespace(name='exec')],
|
||||
use_llm_model_uuid='test-model-uuid',
|
||||
variables={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localagent_uses_exec_for_exact_calculation():
|
||||
provider = RecordingProvider()
|
||||
model = SimpleNamespace(
|
||||
provider=provider,
|
||||
model_entity=SimpleNamespace(
|
||||
uuid='test-model-uuid',
|
||||
name='test-model',
|
||||
abilities=['func_call'],
|
||||
extra_args={},
|
||||
),
|
||||
)
|
||||
|
||||
tool_manager = SimpleNamespace(
|
||||
execute_func_call=AsyncMock(
|
||||
return_value={
|
||||
'session_id': 'avg-query',
|
||||
'backend': 'podman',
|
||||
'status': 'completed',
|
||||
'ok': True,
|
||||
'exit_code': 0,
|
||||
'stdout': '2.5',
|
||||
'stderr': '',
|
||||
'duration_ms': 18,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
app = SimpleNamespace(
|
||||
logger=Mock(),
|
||||
model_mgr=SimpleNamespace(get_model_by_uuid=AsyncMock(return_value=model)),
|
||||
tool_mgr=tool_manager,
|
||||
rag_mgr=SimpleNamespace(),
|
||||
box_service=SimpleNamespace(
|
||||
get_system_guidance=Mock(
|
||||
return_value=(
|
||||
'When the exec tool is available, use it for exact calculations, statistics, '
|
||||
'structured data parsing, and code execution instead of estimating mentally. '
|
||||
'Unless the user explicitly asks for the script, code, or implementation details, '
|
||||
'do not include the generated script in the final answer. '
|
||||
'A default workspace is mounted at /workspace for file tasks.'
|
||||
)
|
||||
),
|
||||
),
|
||||
skill_mgr=SimpleNamespace(
|
||||
get_skills_for_pipeline=AsyncMock(return_value=[]),
|
||||
detect_skill_activation=AsyncMock(return_value=None),
|
||||
build_activation_prompt=Mock(return_value=None),
|
||||
),
|
||||
)
|
||||
|
||||
runner = LocalAgentRunner(app, pipeline_config={})
|
||||
query = make_query()
|
||||
|
||||
results = [message async for message in runner.run(query)]
|
||||
|
||||
assert [message.role for message in results] == ['assistant', 'tool', 'assistant']
|
||||
assert results[-1].content == 'The average is 2.5.'
|
||||
|
||||
tool_manager.execute_func_call.assert_awaited_once()
|
||||
tool_name, tool_parameters = tool_manager.execute_func_call.await_args.args[:2]
|
||||
assert tool_name == 'exec'
|
||||
assert 'print(sum(nums) / len(nums))' in tool_parameters['command']
|
||||
|
||||
first_request = provider.requests[0]
|
||||
assert any(
|
||||
message.role == 'system'
|
||||
and 'exec' in str(message.content)
|
||||
and 'exact calculations' in str(message.content)
|
||||
and 'Unless the user explicitly asks for the script' in str(message.content)
|
||||
and '/workspace' in str(message.content)
|
||||
for message in first_request['messages']
|
||||
)
|
||||
assert [tool.name for tool in first_request['funcs']] == ['exec']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localagent_streaming_tool_error_yields_message_chunks():
|
||||
provider = RecordingStreamProvider()
|
||||
model = SimpleNamespace(
|
||||
provider=provider,
|
||||
model_entity=SimpleNamespace(
|
||||
uuid='test-model-uuid',
|
||||
name='test-model',
|
||||
abilities=['func_call'],
|
||||
extra_args={},
|
||||
),
|
||||
)
|
||||
|
||||
adapter = AsyncMock()
|
||||
adapter.is_stream_output_supported = AsyncMock(return_value=True)
|
||||
|
||||
query = make_query()
|
||||
query.adapter = adapter
|
||||
|
||||
app = SimpleNamespace(
|
||||
logger=Mock(),
|
||||
model_mgr=SimpleNamespace(get_model_by_uuid=AsyncMock(return_value=model)),
|
||||
tool_mgr=SimpleNamespace(execute_func_call=AsyncMock(side_effect=RuntimeError('boom'))),
|
||||
rag_mgr=SimpleNamespace(),
|
||||
box_service=SimpleNamespace(
|
||||
get_system_guidance=Mock(return_value='sandbox guidance'),
|
||||
),
|
||||
skill_mgr=SimpleNamespace(
|
||||
get_skills_for_pipeline=AsyncMock(return_value=[]),
|
||||
detect_skill_activation=AsyncMock(return_value=None),
|
||||
build_activation_prompt=Mock(return_value=None),
|
||||
),
|
||||
)
|
||||
|
||||
runner = LocalAgentRunner(app, pipeline_config={})
|
||||
|
||||
results = [message async for message in runner.run(query)]
|
||||
|
||||
assert all(isinstance(message, provider_message.MessageChunk) for message in results)
|
||||
assert any(message.role == 'tool' and message.content == 'err: boom' for message in results)
|
||||
@@ -21,7 +21,6 @@ from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
from langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl import ModelScopeChatCompletions
|
||||
from langbot.pkg.provider.modelmgr.token import TokenManager
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||
|
||||
|
||||
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
|
||||
@@ -43,8 +42,8 @@ class FakeAgentRunnerRegistry:
|
||||
],
|
||||
capabilities={'tool_calling': True, 'knowledge_retrieval': True, 'multimodal_input': True},
|
||||
permissions={
|
||||
'models': ['list', 'invoke', 'stream'],
|
||||
'tools': ['list', 'detail', 'call'],
|
||||
'models': ['invoke', 'stream'],
|
||||
'tools': ['detail', 'call'],
|
||||
'knowledge_bases': ['list', 'retrieve'],
|
||||
},
|
||||
)
|
||||
@@ -320,8 +319,3 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline()
|
||||
processed_query = result.new_query
|
||||
|
||||
assert processed_query.use_llm_model_uuid == model_uuid
|
||||
|
||||
runner = SimpleNamespace(ap=ap, pipeline_config=pipeline_config)
|
||||
candidates = await LocalAgentRunner._get_model_candidates(runner, processed_query)
|
||||
|
||||
assert [model.model_entity.uuid for model in candidates] == [model_uuid]
|
||||
|
||||
@@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langbot_plugin.api.entities.builtin.agent_runner.manifest import (
|
||||
AgentRunnerCapabilities,
|
||||
AgentRunnerPermissions,
|
||||
)
|
||||
from langbot_plugin.api.entities.builtin.pipeline.query import Query
|
||||
from langbot_plugin.api.entities.builtin.platform.entities import Friend
|
||||
from langbot_plugin.api.entities.builtin.platform.events import FriendMessage
|
||||
@@ -24,22 +28,23 @@ class _FakeRunnerDescriptor:
|
||||
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []},
|
||||
]
|
||||
permissions = {
|
||||
'models': ['list', 'invoke', 'stream'],
|
||||
'tools': ['list', 'detail', 'call'],
|
||||
'models': ['invoke', 'stream'],
|
||||
'tools': ['detail', 'call'],
|
||||
'knowledge_bases': ['list', 'retrieve'],
|
||||
}
|
||||
capabilities = {
|
||||
'tool_calling': True,
|
||||
'knowledge_retrieval': True,
|
||||
'multimodal_input': True,
|
||||
'skill_authoring': True,
|
||||
}
|
||||
permissions = AgentRunnerPermissions.model_validate(permissions)
|
||||
capabilities = AgentRunnerCapabilities(
|
||||
tool_calling=True,
|
||||
knowledge_retrieval=True,
|
||||
multimodal_input=True,
|
||||
skill_authoring=True,
|
||||
)
|
||||
|
||||
def supports_tool_calling(self):
|
||||
return self.capabilities.get('tool_calling', False)
|
||||
return self.capabilities.tool_calling
|
||||
|
||||
def supports_knowledge_retrieval(self):
|
||||
return self.capabilities.get('knowledge_retrieval', False)
|
||||
return self.capabilities.knowledge_retrieval
|
||||
|
||||
|
||||
def _make_query() -> Query:
|
||||
|
||||
@@ -142,10 +142,6 @@ def make_pipeline_handler_import_mocks() -> dict[str, MagicMock]:
|
||||
# Mock core.app - Application class is referenced but not instantiated
|
||||
mock_app = MagicMock()
|
||||
|
||||
# Mock provider.runner - has preregistered_runners attribute
|
||||
mock_runner = MagicMock()
|
||||
mock_runner.preregistered_runners = [] # Empty by default, tests override
|
||||
|
||||
# Mock utils.importutil - prevents auto-import of runners
|
||||
mock_importutil = MagicMock()
|
||||
mock_importutil.import_modules_in_pkg = lambda pkg: None
|
||||
@@ -157,19 +153,11 @@ def make_pipeline_handler_import_mocks() -> dict[str, MagicMock]:
|
||||
'langbot.pkg.pipeline.controller': MagicMock(),
|
||||
'langbot.pkg.pipeline.pipelinemgr': MagicMock(),
|
||||
'langbot.pkg.pipeline.process.process': MagicMock(),
|
||||
'langbot.pkg.provider.runner': mock_runner,
|
||||
'langbot.pkg.utils.importutil': mock_importutil,
|
||||
}
|
||||
|
||||
|
||||
# Package attributes that need to be updated alongside sys.modules mocking.
|
||||
# When Python imports a submodule (e.g., langbot.pkg.provider.runner), it
|
||||
# automatically sets an attribute on the parent package. The import statement
|
||||
# `from ....provider import runner` gets this attribute, not sys.modules directly.
|
||||
# This dict maps mock module names to the parent packages that need attribute updates.
|
||||
_PACKAGE_ATTRIBUTE_UPDATES: dict[str, tuple[str, str]] = {
|
||||
'langbot.pkg.provider.runner': ('langbot.pkg.provider', 'runner'),
|
||||
}
|
||||
_PACKAGE_ATTRIBUTE_UPDATES: dict[str, tuple[str, str]] = {}
|
||||
|
||||
|
||||
def get_handler_modules_to_clear(handler_name: str) -> list[str]:
|
||||
@@ -190,4 +178,4 @@ def get_handler_modules_to_clear(handler_name: str) -> list[str]:
|
||||
'langbot.pkg.pipeline.process.handler',
|
||||
'langbot.pkg.pipeline.process.handlers',
|
||||
f'langbot.pkg.pipeline.process.handlers.{handler_name}',
|
||||
]
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user