feat(agent-runner): enforce typed host permissions

This commit is contained in:
huanghuoguoguo
2026-06-10 22:36:23 +08:00
parent 8938ef7412
commit ea96d37e60
41 changed files with 584 additions and 3862 deletions

View File

@@ -43,7 +43,7 @@ def make_session(
plugin_identity: str = 'test/test-runner',
resources: dict | None = None,
conversation_id: str | None = None,
permissions: dict[str, list[str]] | None = None,
available_apis: dict[str, bool] | None = None,
state_policy: dict[str, typing.Any] | None = None,
state_context: dict[str, typing.Any] | None = None,
) -> dict[str, typing.Any]:
@@ -62,7 +62,7 @@ def make_session(
import time
now = int(time.time())
res = resources if resources is not None else make_resources()
perms = permissions if permissions is not None else {}
apis = available_apis if available_apis is not None else {}
policy = (
state_policy
if state_policy is not None
@@ -85,7 +85,7 @@ def make_session(
'plugin_identity': plugin_identity,
'authorization': {
'resources': res,
'permissions': perms,
'available_apis': apis,
'conversation_id': conversation_id,
'state_policy': policy,
'state_context': context,

View File

@@ -212,7 +212,7 @@ class TestArtifactAccessValidation:
return make_session(
run_id="run_001",
conversation_id=conversation_id,
permissions={"artifacts": ["metadata", "read"]},
available_apis={"artifact_metadata": True, "artifact_read": True},
)
def _call_validate(self, session, metadata, operation="metadata"):
@@ -298,33 +298,23 @@ class TestArtifactAccessValidation:
class TestContextAccessArtifactAPIs:
"""Test ContextAccess reflects artifact API permissions."""
"""Test ContextAccess reflects runtime artifact API availability."""
@pytest.mark.asyncio
async def test_context_access_has_artifact_apis_when_permitted(self):
"""Test ContextAccess shows artifact APIs when permissions allow."""
# This tests the context builder logic
# When artifact permissions include 'metadata' and 'read',
# available_apis should reflect that
permissions = {"artifacts": ["metadata", "read"]}
"""Artifact APIs are exposed through run-scoped available_apis."""
available_apis = {"artifact_metadata": True, "artifact_read": True}
# Check that permissions are properly interpreted
artifact_metadata_enabled = "metadata" in permissions.get("artifacts", [])
artifact_read_enabled = "read" in permissions.get("artifacts", [])
assert artifact_metadata_enabled is True
assert artifact_read_enabled is True
assert available_apis["artifact_metadata"] is True
assert available_apis["artifact_read"] is True
@pytest.mark.asyncio
async def test_context_access_no_artifact_apis_without_permission(self):
"""Test ContextAccess hides artifact APIs when permissions denied."""
permissions = {"artifacts": []}
"""Artifact APIs are absent when the run did not receive them."""
available_apis = {}
artifact_metadata_enabled = "metadata" in permissions.get("artifacts", [])
artifact_read_enabled = "read" in permissions.get("artifacts", [])
assert artifact_metadata_enabled is False
assert artifact_read_enabled is False
assert available_apis.get("artifact_metadata", False) is False
assert available_apis.get("artifact_read", False) is False
class TestArtifactMetadataFieldAlignment:
@@ -376,8 +366,8 @@ class TestArtifactMetadataFieldAlignment:
assert "storage_type" not in result
class TestSessionRegistryPermissions:
"""Test that session registry stores and retrieves permissions correctly."""
class TestSessionRegistryAvailableAPIs:
"""Test that session registry stores and retrieves available APIs correctly."""
@pytest.fixture
def session_registry(self):
@@ -387,8 +377,8 @@ class TestSessionRegistryPermissions:
return get_session_registry()
@pytest.mark.asyncio
async def test_register_stores_permissions(self, session_registry):
"""Test that register() stores permissions from descriptor."""
async def test_register_stores_available_apis(self, session_registry):
"""Test that register() stores runtime API availability."""
await session_registry.register(
run_id="run_001",
runner_id="plugin:author/plugin/runner",
@@ -402,24 +392,26 @@ class TestSessionRegistryPermissions:
"storage": {"plugin_storage": True, "workspace_storage": False},
"platform_capabilities": {},
},
permissions={
"artifacts": ["metadata", "read"],
"history": ["page"],
"events": ["get"],
available_apis={
"artifact_metadata": True,
"artifact_read": True,
"history_page": True,
"event_get": True,
},
conversation_id="conv_001",
)
session = await session_registry.get("run_001")
assert session is not None
permissions = session["authorization"]["permissions"]
assert permissions["artifacts"] == ["metadata", "read"]
assert permissions["history"] == ["page"]
assert permissions["events"] == ["get"]
available_apis = session["authorization"]["available_apis"]
assert available_apis["artifact_metadata"] is True
assert available_apis["artifact_read"] is True
assert available_apis["history_page"] is True
assert available_apis["event_get"] is True
@pytest.mark.asyncio
async def test_register_with_empty_permissions(self, session_registry):
"""Test that register() handles empty permissions."""
async def test_register_with_empty_available_apis(self, session_registry):
"""Test that register() handles empty API availability."""
await session_registry.register(
run_id="run_002",
runner_id="plugin:author/plugin/runner",
@@ -433,13 +425,13 @@ class TestSessionRegistryPermissions:
"storage": {"plugin_storage": True, "workspace_storage": False},
"platform_capabilities": {},
},
permissions={},
available_apis={},
conversation_id="conv_001",
)
session = await session_registry.get("run_002")
assert session is not None
assert session["authorization"]["permissions"] == {}
assert session["authorization"]["available_apis"] == {}
class TestArtifactStoreRealSQLite:

View File

@@ -11,6 +11,7 @@ import pytest
from unittest.mock import MagicMock
from langbot.pkg.agent.runner.context_builder import AgentRunContextBuilder
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope, StatePolicy
from langbot_plugin.api.entities.builtin.agent_runner.event import ActorContext
from langbot_plugin.api.entities.builtin.agent_runner.input import AgentInput
@@ -25,6 +26,27 @@ class MockApplication:
self.persistence_mgr.get_db_engine = MagicMock()
def make_descriptor(
permissions: dict | None = None,
) -> AgentRunnerDescriptor:
return AgentRunnerDescriptor(
id='plugin:test/runner/default',
source='plugin',
label={'en_US': 'Test Runner'},
plugin_author='test',
plugin_name='runner',
runner_name='default',
permissions=permissions
if permissions is not None
else {
'history': ['page', 'search'],
'events': ['get', 'page'],
'artifacts': ['metadata', 'read'],
'storage': ['plugin'],
},
)
class TestContextAccessStateDetermination:
"""Tests for ContextAccess.state field determination - real calls to _build_context_access."""
@@ -54,10 +76,7 @@ class TestContextAccessStateDetermination:
@pytest.fixture
def mock_descriptor(self):
"""Create mock runner descriptor."""
descriptor = MagicMock()
descriptor.id = 'plugin:test/runner/default'
descriptor.permissions = {}
return descriptor
return make_descriptor()
@pytest.mark.asyncio
async def test_enable_state_true_with_scopes_sets_state_true(self, mock_app, mock_event, mock_descriptor):
@@ -237,7 +256,7 @@ class TestBindingWithStatePolicy:
class TestContextAccessOtherAPIs:
"""Tests for other available_apis fields based on permissions."""
"""Tests for other available_apis fields based on run scope."""
@pytest.fixture
def mock_app(self):
@@ -245,16 +264,12 @@ class TestContextAccessOtherAPIs:
return MockApplication()
@pytest.mark.asyncio
async def test_history_apis_based_on_permissions(self, mock_app):
"""History APIs availability based on runner permissions."""
async def test_history_apis_enabled_with_conversation(self, mock_app):
"""History APIs are available when the run has a conversation scope."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.thread_id = None
mock_descriptor = MagicMock()
mock_descriptor.permissions = {
'history': ['page', 'search'],
}
mock_descriptor = make_descriptor()
binding = AgentBinding(
binding_id='binding_001',
@@ -268,21 +283,16 @@ class TestContextAccessOtherAPIs:
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# History APIs enabled based on permissions
assert context_access['available_apis']['history_page'] is True
assert context_access['available_apis']['history_search'] is True
@pytest.mark.asyncio
async def test_event_apis_based_on_permissions(self, mock_app):
"""Event APIs availability based on runner permissions."""
async def test_event_apis_enabled_by_default(self, mock_app):
"""Event APIs are available based on current run scope."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.thread_id = None
mock_descriptor = MagicMock()
mock_descriptor.permissions = {
'events': ['get', 'page'],
}
mock_descriptor = make_descriptor()
binding = AgentBinding(
binding_id='binding_001',
@@ -296,21 +306,16 @@ class TestContextAccessOtherAPIs:
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# Event APIs enabled based on permissions
assert context_access['available_apis']['event_get'] is True
assert context_access['available_apis']['event_page'] is True
@pytest.mark.asyncio
async def test_artifact_apis_based_on_permissions(self, mock_app):
"""Artifact APIs availability based on runner permissions."""
async def test_artifact_apis_enabled_by_default(self, mock_app):
"""Artifact APIs are available based on current run scope."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.thread_id = None
mock_descriptor = MagicMock()
mock_descriptor.permissions = {
'artifacts': ['metadata', 'read'],
}
mock_descriptor = make_descriptor()
binding = AgentBinding(
binding_id='binding_001',
@@ -324,19 +329,16 @@ class TestContextAccessOtherAPIs:
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# Artifact APIs enabled based on permissions
assert context_access['available_apis']['artifact_metadata'] is True
assert context_access['available_apis']['artifact_read'] is True
@pytest.mark.asyncio
async def test_no_permissions_all_apis_disabled(self, mock_app):
"""All pull APIs disabled when permissions are empty."""
async def test_conversation_required_apis_disabled_without_conversation(self, mock_app):
"""Conversation-scoped APIs are disabled when the run has no conversation."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.conversation_id = None
mock_event.thread_id = None
mock_descriptor = MagicMock()
mock_descriptor.permissions = {} # No permissions
mock_descriptor = make_descriptor()
binding = AgentBinding(
binding_id='binding_001',
@@ -350,11 +352,37 @@ class TestContextAccessOtherAPIs:
# Real call
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
# All pull APIs should be disabled
assert context_access['available_apis']['history_page'] is False
assert context_access['available_apis']['history_search'] is False
assert context_access['available_apis']['event_get'] is True
assert context_access['available_apis']['event_page'] is False
assert context_access['available_apis']['artifact_metadata'] is True
assert context_access['available_apis']['artifact_read'] is True
assert context_access['available_apis']['state'] is False
@pytest.mark.asyncio
async def test_manifest_permissions_disable_context_apis(self, mock_app):
"""Pull APIs are disabled when manifest permissions omit them."""
mock_event = MagicMock()
mock_event.conversation_id = 'conv_001'
mock_event.thread_id = None
mock_descriptor = make_descriptor(permissions={})
binding = AgentBinding(
binding_id='binding_001',
runner_id='plugin:test/runner/default',
scope=BindingScope(scope_type='agent', scope_id='conv_001'),
state_policy=StatePolicy(enable_state=False, state_scopes=[]),
)
builder = AgentRunContextBuilder(mock_app)
context_access = await builder._build_context_access(mock_event, mock_descriptor, binding)
assert context_access['available_apis']['history_page'] is False
assert context_access['available_apis']['history_search'] is False
assert context_access['available_apis']['event_get'] is False
assert context_access['available_apis']['event_page'] is False
assert context_access['available_apis']['artifact_metadata'] is False
assert context_access['available_apis']['artifact_read'] is False
assert context_access['available_apis']['state'] is False
assert context_access['available_apis']['storage'] is False

View File

@@ -18,6 +18,7 @@ from langbot.pkg.agent.runner.context_builder import (
AgentRunContextBuilder,
AgentResources as BuilderResources,
)
from langbot.pkg.agent.runner.descriptor import AgentRunnerDescriptor
from langbot.pkg.agent.runner.host_models import AgentEventEnvelope, AgentBinding, BindingScope
from langbot.pkg.core import app
@@ -88,13 +89,20 @@ class TestContextValidation:
def _make_descriptor(self):
"""Create a mock runner descriptor."""
descriptor = MagicMock()
descriptor.id = "plugin:test/plugin/runner"
descriptor.permissions = {
'history': ['page', 'search'],
'events': ['get', 'page'],
}
return descriptor
return AgentRunnerDescriptor(
id="plugin:test/plugin/runner",
source="plugin",
label={"en_US": "Test Runner"},
plugin_author="test",
plugin_name="plugin",
runner_name="runner",
permissions={
"history": ["page", "search"],
"events": ["get", "page"],
"artifacts": ["metadata", "read"],
"storage": ["plugin", "workspace"],
},
)
@pytest.mark.asyncio
async def test_build_context_from_event_validates(self):

View File

@@ -23,12 +23,6 @@ from langbot_plugin.api.entities.builtin.agent_runner.result import (
AgentRunResult,
AgentRunResultType,
)
from langbot_plugin.api.entities.builtin.agent_runner.capabilities import (
AgentRunnerCapabilities,
)
from langbot_plugin.api.entities.builtin.agent_runner.permissions import (
AgentRunnerPermissions,
)
# Import LangBot host models
from langbot.pkg.agent.runner.query_entry_adapter import QueryEntryAdapter
@@ -57,6 +51,7 @@ class TestQueryToEventEnvelope:
assert event.input is not None
assert event.input.text == "Hello world"
assert "message_chain" not in event.input.model_dump()
def test_query_to_event_conversation(self, mock_query):
"""Test conversation context extraction."""
@@ -232,43 +227,6 @@ class TestHostManagedHistoryNotInProtocol:
assert "messages" not in ctx_fields
class TestSDKCapabilitiesProtocolV1:
"""Test SDK capabilities for Protocol v1."""
def test_self_managed_context_default_true(self):
"""Test self_managed_context defaults to True for Protocol v1."""
caps = AgentRunnerCapabilities()
assert caps.self_managed_context is True
def test_event_context_default_true(self):
"""Test event_context defaults to True for Protocol v1."""
caps = AgentRunnerCapabilities()
assert caps.event_context is True
class TestSDKPermissionsProtocolV1:
"""Test SDK permissions for Protocol v1."""
def test_permissions_new_fields(self):
"""Test new permission fields for Protocol v1."""
perms = AgentRunnerPermissions(
models=["invoke", "stream", "rerank"],
tools=["detail", "call"],
knowledge_bases=["list", "retrieve"],
history=["page", "search"],
events=["get", "page"],
artifacts=["metadata", "read"],
storage=["plugin", "workspace", "binding"],
)
assert perms.history == ["page", "search"]
assert perms.events == ["get", "page"]
assert perms.artifacts == ["metadata", "read"]
assert perms.storage == ["plugin", "workspace", "binding"]
class TestSDKResultProtocolV1:
"""Test SDK AgentRunResult for Protocol v1."""

View File

@@ -64,7 +64,7 @@ async def _register_session(
*,
run_id='run_1',
conversation_id='conv_1',
permissions=None,
available_apis=None,
):
await session_registry.register(
run_id=run_id,
@@ -73,13 +73,13 @@ async def _register_session(
plugin_identity='test/runner',
resources=make_resources(),
conversation_id=conversation_id,
permissions=permissions or {},
available_apis=available_apis or {},
)
@pytest.mark.asyncio
async def test_history_page_requires_manifest_permission(session_registry, db_engine):
await _register_session(session_registry, permissions={'history': []})
async def test_history_page_requires_runtime_capability(session_registry, db_engine):
await _register_session(session_registry, available_apis={'history_page': False})
handler = _handler(db_engine, session_registry)
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
@@ -94,7 +94,7 @@ async def test_history_page_requires_manifest_permission(session_registry, db_en
@pytest.mark.asyncio
async def test_history_page_rejects_cross_conversation(session_registry, db_engine):
await _register_session(session_registry, permissions={'history': ['page']})
await _register_session(session_registry, available_apis={'history_page': True})
handler = _handler(db_engine, session_registry)
history_page = handler.actions[PluginToRuntimeAction.HISTORY_PAGE.value]
@@ -110,7 +110,7 @@ async def test_history_page_rejects_cross_conversation(session_registry, db_engi
@pytest.mark.asyncio
async def test_history_search_rejects_filter_conversation_override(session_registry, db_engine):
await _register_session(session_registry, permissions={'history': ['search']})
await _register_session(session_registry, available_apis={'history_search': True})
handler = _handler(db_engine, session_registry)
history_search = handler.actions[PluginToRuntimeAction.HISTORY_SEARCH.value]
@@ -126,8 +126,8 @@ async def test_history_search_rejects_filter_conversation_override(session_regis
@pytest.mark.asyncio
async def test_event_page_requires_manifest_permission(session_registry, db_engine):
await _register_session(session_registry, permissions={'events': []})
async def test_event_page_requires_runtime_capability(session_registry, db_engine):
await _register_session(session_registry, available_apis={'event_page': False})
handler = _handler(db_engine, session_registry)
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
@@ -142,7 +142,7 @@ async def test_event_page_requires_manifest_permission(session_registry, db_engi
@pytest.mark.asyncio
async def test_event_page_rejects_cross_conversation(session_registry, db_engine):
await _register_session(session_registry, permissions={'events': ['page']})
await _register_session(session_registry, available_apis={'event_page': True})
handler = _handler(db_engine, session_registry)
event_page = handler.actions[PluginToRuntimeAction.EVENT_PAGE.value]
@@ -158,7 +158,7 @@ async def test_event_page_rejects_cross_conversation(session_registry, db_engine
@pytest.mark.asyncio
async def test_event_get_returns_sdk_record_projection(session_registry, db_engine):
await _register_session(session_registry, permissions={'events': ['get']})
await _register_session(session_registry, available_apis={'event_get': True})
store = EventLogStore(db_engine)
event_id = await store.append_event(
event_id='evt_projection_1',
@@ -193,7 +193,7 @@ async def test_event_get_returns_sdk_record_projection(session_registry, db_engi
@pytest.mark.asyncio
async def test_event_page_returns_sdk_page_projection(session_registry, db_engine):
await _register_session(session_registry, permissions={'events': ['page']})
await _register_session(session_registry, available_apis={'event_page': True})
store = EventLogStore(db_engine)
await store.append_event(
event_id='evt_projection_page_1',

View File

@@ -159,17 +159,19 @@ def make_descriptor() -> AgentRunnerDescriptor:
"knowledge_retrieval": True,
"skill_authoring": True,
},
permissions={
"models": ["invoke", "stream"],
"tools": ["detail", "call"],
"knowledge_bases": ["list", "retrieve"],
"history": ["page", "search"],
"events": ["get", "page"],
"artifacts": ["metadata", "read"],
"storage": ["plugin"],
},
config_schema=[
{"name": "model", "type": "model-fallback-selector"},
{"name": "knowledge-bases", "type": "knowledge-base-multi-selector", "default": []},
],
permissions={
"models": ["invoke", "stream"],
"tools": ["list", "detail", "call"],
"knowledge_bases": ["list", "retrieve"],
"storage": ["plugin"],
"files": [],
},
)

View File

@@ -13,13 +13,23 @@ from langbot.pkg.agent.runner.resource_builder import AgentResourceBuilder
RUNNER_ID = 'plugin:test/runner/default'
FULL_PERMISSIONS = {
'models': ['invoke', 'stream', 'rerank'],
'tools': ['detail', 'call'],
'knowledge_bases': ['list', 'retrieve'],
'history': ['page', 'search'],
'events': ['get', 'page'],
'artifacts': ['metadata', 'read'],
'storage': ['plugin', 'workspace'],
'files': ['config', 'knowledge'],
}
def make_descriptor(
*,
permissions: dict | None = None,
config_schema: list[dict] | None = None,
capabilities: dict | None = None,
permissions: dict | None = None,
) -> AgentRunnerDescriptor:
return AgentRunnerDescriptor(
id=RUNNER_ID,
@@ -29,7 +39,7 @@ def make_descriptor(
plugin_name='runner',
runner_name='default',
capabilities=capabilities or {},
permissions=permissions or {'models': ['invoke', 'stream']},
permissions=permissions if permissions is not None else FULL_PERMISSIONS,
config_schema=config_schema or [],
)
@@ -113,7 +123,6 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app
app.model_mgr.get_model_by_uuid = AsyncMock(side_effect=get_model_by_uuid)
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(side_effect=get_rerank_model_by_uuid)
descriptor = make_descriptor(
permissions={'models': ['invoke', 'stream', 'rerank']},
config_schema=[
{'name': 'model', 'type': 'model-fallback-selector'},
{'name': 'aux-model', 'type': 'llm-model-selector'},
@@ -137,16 +146,16 @@ async def test_build_models_authorizes_config_declared_llm_and_rerank_models(app
@pytest.mark.asyncio
async def test_build_models_still_honors_manifest_permissions(app):
"""Config-selected models should not bypass runner manifest permissions."""
async def test_build_models_from_config_without_manifest_acl(app):
"""Config-selected models are not projected without manifest model permissions."""
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(return_value=make_model(model_type='rerank'))
descriptor = make_descriptor(
permissions={'models': []},
config_schema=[
{'name': 'model', 'type': 'model-fallback-selector'},
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
],
permissions={},
)
query = make_query({
'model': {'primary': 'primary', 'fallbacks': ['fallback']},
@@ -156,19 +165,16 @@ async def test_build_models_still_honors_manifest_permissions(app):
resources = await build_resources(app, query, descriptor)
assert resources['models'] == []
app.model_mgr.get_model_by_uuid.assert_not_awaited()
app.model_mgr.get_rerank_model_by_uuid.assert_not_awaited()
@pytest.mark.asyncio
async def test_build_models_authorizes_rerank_only_runner(app):
"""A rerank-only runner should receive config-selected rerank models."""
async def test_build_models_authorizes_rerank_and_llm_refs_from_config(app):
"""Config-selected model references are projected regardless of method granularity."""
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
return_value=make_model(model_type='rerank', provider='rerank-provider')
)
descriptor = make_descriptor(
permissions={'models': ['rerank']},
config_schema=[
{'name': 'model', 'type': 'llm-model-selector'},
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
@@ -181,10 +187,39 @@ async def test_build_models_authorizes_rerank_only_runner(app):
resources = await build_resources(app, query, descriptor)
assert resources['models'] == [
{'model_id': 'llm', 'model_type': 'llm', 'provider': 'test-provider'},
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider'},
]
@pytest.mark.asyncio
async def test_build_models_manifest_permission_narrows_binding(app):
"""Manifest model permissions narrower than binding should remove LLM grants."""
app.model_mgr.get_model_by_uuid = AsyncMock(return_value=make_model())
app.model_mgr.get_rerank_model_by_uuid = AsyncMock(
return_value=make_model(model_type='rerank', provider='rerank-provider')
)
descriptor = make_descriptor(
config_schema=[
{'name': 'model', 'type': 'llm-model-selector'},
{'name': 'rerank-model', 'type': 'rerank-model-selector'},
],
permissions={
**FULL_PERMISSIONS,
'models': ['rerank'],
},
)
query = make_query({
'model': 'llm',
'rerank-model': 'rerank',
})
resources = await build_resources(app, query, descriptor)
assert resources['models'] == [
{'model_id': 'rerank', 'model_type': 'rerank', 'provider': 'rerank-provider'},
]
app.model_mgr.get_model_by_uuid.assert_not_awaited()
@pytest.mark.asyncio
@@ -212,10 +247,7 @@ async def test_build_models_deduplicates_query_and_config_models(app):
async def test_build_tools_authorizes_query_declared_tools(app):
"""Tools discovered by Pipeline preprocessing become run-scoped authorized resources."""
descriptor = make_descriptor(
permissions={
'models': [],
'tools': ['detail', 'call'],
},
capabilities={'tool_calling': True},
)
query = make_query(
{},
@@ -241,14 +273,32 @@ async def test_build_tools_authorizes_query_declared_tools(app):
]
@pytest.mark.asyncio
async def test_build_tools_manifest_permission_denies_binding_tools(app):
"""Binding tool grants should be removed when manifest does not request tools."""
descriptor = make_descriptor(
capabilities={'tool_calling': True},
permissions={
**FULL_PERMISSIONS,
'tools': [],
},
)
query = make_query(
{},
use_funcs=[
{'name': 'qa_plugin_echo', 'description': 'Echo test tool'},
],
)
resources = await build_resources(app, query, descriptor)
assert resources['tools'] == []
@pytest.mark.asyncio
async def test_build_knowledge_bases_unions_config_and_policy_grants(app):
descriptor = make_descriptor(
capabilities={'knowledge_retrieval': True},
permissions={
'models': [],
'knowledge_bases': ['retrieve'],
},
config_schema=[
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
],
@@ -273,3 +323,43 @@ async def test_build_knowledge_bases_unions_config_and_policy_grants(app):
{'kb_id': 'kb_config', 'kb_name': 'name-kb_config', 'kb_type': 'default'},
{'kb_id': 'kb_policy', 'kb_name': 'name-kb_policy', 'kb_type': 'default'},
]
@pytest.mark.asyncio
async def test_build_knowledge_bases_manifest_permission_denies_binding_kbs(app):
descriptor = make_descriptor(
capabilities={'knowledge_retrieval': True},
permissions={
**FULL_PERMISSIONS,
'knowledge_bases': [],
},
config_schema=[
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector'},
],
)
query = make_query(
{'knowledge-bases': ['kb_config']},
variables={'_knowledge_base_uuids': ['kb_policy']},
)
resources = await build_resources(app, query, descriptor)
assert resources['knowledge_bases'] == []
@pytest.mark.asyncio
async def test_build_storage_intersects_manifest_and_binding_policy(app):
descriptor = make_descriptor(
permissions={
**FULL_PERMISSIONS,
'storage': ['plugin'],
},
)
query = make_query({})
resources = await build_resources(app, query, descriptor)
assert resources['storage'] == {
'plugin_storage': True,
'workspace_storage': False,
}

View File

@@ -14,12 +14,15 @@ class FakeApplication:
"""Fake Application for testing."""
def __init__(self):
class FakeLogger:
def __init__(self):
self.warnings = []
def info(self, msg):
pass
def debug(self, msg):
pass
def warning(self, msg):
pass
self.warnings.append(msg)
def error(self, msg):
pass
@@ -67,7 +70,7 @@ class TestNormalizeMessageDelta:
@pytest.mark.asyncio
async def test_normalize_message_delta_missing_chunk(self):
"""Normalize message.delta without chunk data."""
"""Invalid message.delta payload is dropped."""
normalizer = AgentResultNormalizer(FakeApplication())
descriptor = make_descriptor()
@@ -76,10 +79,9 @@ class TestNormalizeMessageDelta:
'data': {},
}
with pytest.raises(RunnerProtocolError) as exc_info:
await normalizer.normalize(result_dict, descriptor)
result = await normalizer.normalize(result_dict, descriptor)
assert 'missing chunk data' in str(exc_info.value)
assert result is None
class TestNormalizeMessageCompleted:
@@ -110,7 +112,7 @@ class TestNormalizeMessageCompleted:
@pytest.mark.asyncio
async def test_normalize_message_completed_missing_message(self):
"""Normalize message.completed without message data."""
"""Invalid message.completed payload is dropped."""
normalizer = AgentResultNormalizer(FakeApplication())
descriptor = make_descriptor()
@@ -119,10 +121,9 @@ class TestNormalizeMessageCompleted:
'data': {},
}
with pytest.raises(RunnerProtocolError) as exc_info:
await normalizer.normalize(result_dict, descriptor)
result = await normalizer.normalize(result_dict, descriptor)
assert 'missing message data' in str(exc_info.value)
assert result is None
class TestNormalizeRunCompleted:
@@ -260,13 +261,57 @@ class TestNormalizeNonMessageResults:
'type': 'action.requested',
'data': {
'action': 'platform.message.edit',
'parameters': {},
'payload': {},
},
}
result = await normalizer.normalize(result_dict, descriptor)
assert result is None
@pytest.mark.asyncio
async def test_invalid_state_updated_payload_is_dropped(self):
"""Invalid state.updated payload returns None with a warning."""
app = FakeApplication()
normalizer = AgentResultNormalizer(app)
descriptor = make_descriptor()
result = await normalizer.normalize(
{
'type': 'state.updated',
'data': {
'scope': 'invalid',
'key': 'k',
'value': 'v',
},
},
descriptor,
)
assert result is None
assert app.logger.warnings
@pytest.mark.asyncio
async def test_invalid_artifact_created_payload_is_dropped(self):
"""Invalid artifact.created payload returns None with a warning."""
app = FakeApplication()
normalizer = AgentResultNormalizer(app)
descriptor = make_descriptor()
result = await normalizer.normalize(
{
'type': 'artifact.created',
'data': {
'artifact_id': 'artifact-1',
'artifact_type': 'file',
'content_base64': 'not base64',
},
},
descriptor,
)
assert result is None
assert app.logger.warnings
class TestNormalizeInvalidResults:
"""Tests for handling invalid results."""

View File

@@ -63,7 +63,7 @@ class TestSessionRegistryBasic:
query_id=1,
plugin_identity='test/my-runner',
resources=resources,
permissions={'models': ['invoke']},
available_apis={'history_page': True},
conversation_id='conv_001',
)
@@ -74,7 +74,7 @@ class TestSessionRegistryBasic:
assert session is not None
authorization = session['authorization']
assert authorization['conversation_id'] == 'conv_001'
assert authorization['permissions'] == {'models': ['invoke']}
assert authorization['available_apis'] == {'history_page': True}
assert registry.is_resource_allowed(session, 'model', 'model_001') is True
assert registry.is_resource_allowed(session, 'model', 'model_late') is False
assert registry.is_resource_allowed(session, 'storage', 'workspace') is False

View File

@@ -14,6 +14,23 @@ from tests.factories import FakeApp
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
_current_runner_class = None
def _default_runner_class():
from langbot_plugin.api.entities.builtin.provider.message import Message
class DefaultRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='fake response')
return DefaultRunner
def runner_pipeline_config(output_misc: dict) -> dict:
@@ -47,21 +64,8 @@ def mock_circular_import_chain():
make_pipeline_handler_import_mocks,
get_handler_modules_to_clear,
)
from langbot_plugin.api.entities.builtin.provider.message import Message
mocks = make_pipeline_handler_import_mocks()
# Create a default runner that yields a simple response
class DefaultRunner:
name = 'local-agent'
def __init__(self, app, config):
self.app = app
self.config = config
async def run(self, query):
yield Message(role='assistant', content='fake response')
mocks['langbot.pkg.provider.runner'].preregistered_runners = [DefaultRunner]
clear = get_handler_modules_to_clear('chat')
with isolated_sys_modules(mocks=mocks, clear=clear):
@@ -75,9 +79,7 @@ def fake_app():
class ProviderRunnerBackedOrchestrator:
async def run_from_query(self, query):
import sys
runner_class = sys.modules['langbot.pkg.provider.runner'].preregistered_runners[0]
runner_class = _current_runner_class or _default_runner_class()
runner = runner_class(app, {})
async for result in runner.run(query):
yield result
@@ -103,10 +105,15 @@ def mock_event_ctx():
@pytest.fixture
def set_runner():
"""Factory fixture to set a custom runner for tests."""
global _current_runner_class
previous = _current_runner_class
def _set_runner(runner_class):
import sys
sys.modules['langbot.pkg.provider.runner'].preregistered_runners = [runner_class]
return _set_runner
global _current_runner_class
_current_runner_class = runner_class
yield _set_runner
_current_runner_class = previous
# ============== CACHED LAZY IMPORTS ==============

View File

@@ -1,353 +0,0 @@
"""
Unit tests for N8nServiceAPIRunner._process_response
Tests cover four scenarios:
- Stream adapter + n8n stream format (type:item/end)
- Stream adapter + n8n plain JSON
- Non-stream adapter + n8n stream format
- Non-stream adapter + n8n plain JSON
"""
from __future__ import annotations
import json
import sys
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
import langbot_plugin.api.entities.builtin.provider.message as provider_message
# Break the circular import chain while importing n8nsvapi:
# n8nsvapi → runner → app → pipelinemgr → all runners → runner (partially init)
# The stubs are restored in a ``finally`` block so this module does NOT pollute
# sys.modules for other test modules (e.g. ones importing the real
# LocalAgentRunner, which would otherwise inherit ``object`` and break).
# Mirrors master's intent but uses try/finally so a raised import doesn't
# leave the global namespace in a stubbed state, and includes
# ``langbot.pkg.utils.httpclient`` which master didn't stub.
_runner_stub = MagicMock()
_runner_stub.runner_class = lambda name: (lambda cls: cls) # no-op decorator
_runner_stub.RequestRunner = object
_import_stubs = {
'langbot.pkg.provider.runner': _runner_stub,
'langbot.pkg.core.app': MagicMock(),
'langbot.pkg.utils.httpclient': MagicMock(),
}
_saved_modules = {name: sys.modules.get(name) for name in _import_stubs}
for _name, _stub in _import_stubs.items():
sys.modules[_name] = _stub
try:
from langbot.pkg.provider.runners.n8nsvapi import N8nServiceAPIRunner
finally:
for _name, _original in _saved_modules.items():
if _original is None:
sys.modules.pop(_name, None)
else:
sys.modules[_name] = _original
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_runner(output_key: str = 'response') -> N8nServiceAPIRunner:
ap = Mock()
ap.logger = Mock()
pipeline_config = {
'ai': {
'n8n-service-api': {
'webhook-url': 'http://test-n8n/webhook',
'output-key': output_key,
'auth-type': 'none',
}
}
}
return N8nServiceAPIRunner(ap, pipeline_config)
def make_mock_response(chunks: list[bytes | str], status: int = 200):
"""Build a minimal aiohttp.ClientResponse mock with iter_chunked support."""
response = Mock()
response.status = status
async def iter_chunked(size):
for chunk in chunks:
yield chunk
response.content = Mock()
response.content.iter_chunked = iter_chunked
return response
async def collect_chunks(runner: N8nServiceAPIRunner, chunks: list[bytes | str]):
"""Run _process_response and collect all yielded MessageChunks."""
response = make_mock_response(chunks)
result = []
async for chunk in runner._process_response(response):
result.append(chunk)
return result
# ---------------------------------------------------------------------------
# _process_response: stream format (type:item/end)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stream_format_single_item():
"""Single item + end in one chunk yields final chunk with full content."""
runner = make_runner()
data = b'{"type":"item","content":"hello"}{"type":"end"}'
chunks = await collect_chunks(runner, [data])
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].content == 'hello'
assert chunks[0].msg_sequence == 1
@pytest.mark.asyncio
async def test_stream_format_multi_item_accumulates():
"""Multiple items accumulate into full_content."""
runner = make_runner()
chunks_data = [
b'{"type":"item","content":"foo"}',
b'{"type":"item","content":"bar"}',
b'{"type":"end"}',
]
chunks = await collect_chunks(runner, chunks_data)
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].content == 'foobar'
assert chunks[0].msg_sequence == 1
@pytest.mark.asyncio
async def test_stream_format_batches_every_8_items():
"""Every 8th item triggers an intermediate yield before the final."""
runner = make_runner()
items = [f'{{"type":"item","content":"{i}"}}' for i in range(8)]
items.append('{"type":"end"}')
data = ''.join(items).encode()
chunks = await collect_chunks(runner, [data])
assert len(chunks) == 2
assert chunks[0].is_final is False
assert chunks[0].content == '01234567'
assert chunks[0].msg_sequence == 1
assert chunks[1].is_final is True
assert chunks[1].content == '01234567'
assert chunks[1].msg_sequence == 2
@pytest.mark.asyncio
async def test_stream_format_split_across_network_chunks():
"""JSON split across multiple network chunks is reassembled correctly."""
runner = make_runner()
part1 = b'{"type":"item","con'
part2 = b'tent":"world"}{"type":"end"}'
chunks = await collect_chunks(runner, [part1, part2])
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].content == 'world'
@pytest.mark.asyncio
async def test_stream_format_no_spurious_empty_yield():
"""chunk_idx==0 guard prevents spurious empty yield before any item is received."""
runner = make_runner()
# Send some non-stream JSON first, then stream
data = b'{"type":"item","content":"x"}{"type":"end"}'
chunks = await collect_chunks(runner, [data])
assert len(chunks) == 1
assert chunks[0].content == 'x'
# ---------------------------------------------------------------------------
# _process_response: plain JSON fallback
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plain_json_with_output_key():
"""Plain JSON with matching output_key extracts value via output_key."""
runner = make_runner(output_key='response')
data = json.dumps({'response': 'hello world'}).encode()
chunks = await collect_chunks(runner, [data])
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].content == 'hello world'
@pytest.mark.asyncio
async def test_plain_json_output_key_not_found():
"""Plain JSON without output_key falls back to entire JSON string."""
runner = make_runner(output_key='response')
payload = {'other_key': 'hello'}
data = json.dumps(payload).encode()
chunks = await collect_chunks(runner, [data])
assert len(chunks) == 1
assert chunks[0].is_final is True
assert json.loads(chunks[0].content) == payload
@pytest.mark.asyncio
async def test_plain_json_output_key_empty_string():
"""output_key present but value is empty string — returns empty string, not whole JSON."""
runner = make_runner(output_key='response')
data = json.dumps({'response': ''}).encode()
chunks = await collect_chunks(runner, [data])
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].content == ''
@pytest.mark.asyncio
async def test_plain_json_non_dict_response():
"""Plain JSON array falls back to raw text."""
runner = make_runner()
data = b'["a", "b"]'
chunks = await collect_chunks(runner, [data])
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].content == '["a", "b"]'
@pytest.mark.asyncio
async def test_invalid_json_returns_raw_text():
"""Non-JSON response returns raw text as-is."""
runner = make_runner()
data = b'plain text response'
chunks = await collect_chunks(runner, [data])
assert len(chunks) == 1
assert chunks[0].is_final is True
assert chunks[0].content == 'plain text response'
# ---------------------------------------------------------------------------
# _call_webhook: output type depends on is_stream
# ---------------------------------------------------------------------------
def make_query(is_stream: bool):
"""Build a minimal Query mock."""
query = Mock()
query.adapter = AsyncMock()
query.adapter.is_stream_output_supported = AsyncMock(return_value=is_stream)
session = Mock()
session.using_conversation = Mock()
session.using_conversation.uuid = 'test-uuid'
session.launcher_type = Mock()
session.launcher_type.value = 'person'
session.launcher_id = '12345'
query.session = session
query.user_message = Mock()
query.user_message.content = 'hi'
query.variables = {}
return query
def make_http_session_mock(response_bytes: bytes, status: int = 200):
"""Mock httpclient.get_session() returning a session whose post() yields response_bytes."""
mock_response = make_mock_response([response_bytes], status=status)
mock_response.status = status
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_session = Mock()
mock_session.post = Mock(return_value=mock_cm)
return mock_session
@pytest.mark.asyncio
async def test_call_webhook_nonstream_adapter_plain_json():
"""Non-stream adapter + plain JSON → single Message with output_key value."""
runner = make_runner(output_key='response')
query = make_query(is_stream=False)
http_session = make_http_session_mock(json.dumps({'response': 'result text'}).encode())
with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
results = []
async for msg in runner._call_webhook(query):
results.append(msg)
assert len(results) == 1
assert isinstance(results[0], provider_message.Message)
assert results[0].content == 'result text'
@pytest.mark.asyncio
async def test_call_webhook_stream_adapter_stream_format():
"""Stream adapter + stream format → MessageChunks, last is_final."""
runner = make_runner()
query = make_query(is_stream=True)
data = b'{"type":"item","content":"hi"}{"type":"end"}'
http_session = make_http_session_mock(data)
with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
results = []
async for msg in runner._call_webhook(query):
results.append(msg)
assert all(isinstance(r, provider_message.MessageChunk) for r in results)
assert results[-1].is_final is True
assert results[-1].content == 'hi'
@pytest.mark.asyncio
async def test_call_webhook_stream_adapter_plain_json():
"""Stream adapter + plain JSON → single MessageChunk with is_final=True."""
runner = make_runner(output_key='response')
query = make_query(is_stream=True)
data = json.dumps({'response': 'fallback'}).encode()
http_session = make_http_session_mock(data)
with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
results = []
async for msg in runner._call_webhook(query):
results.append(msg)
assert all(isinstance(r, provider_message.MessageChunk) for r in results)
assert results[-1].is_final is True
assert results[-1].content == 'fallback'
@pytest.mark.asyncio
async def test_call_webhook_nonstream_adapter_stream_format():
"""Non-stream adapter + stream format → single Message with accumulated content."""
runner = make_runner()
query = make_query(is_stream=False)
data = b'{"type":"item","content":"foo"}{"type":"item","content":"bar"}{"type":"end"}'
http_session = make_http_session_mock(data)
with patch('langbot.pkg.provider.runners.n8nsvapi.httpclient.get_session', return_value=http_session):
results = []
async for msg in runner._call_webhook(query):
results.append(msg)
assert len(results) == 1
assert isinstance(results[0], provider_message.Message)
assert results[0].content == 'foobar'

View File

@@ -73,8 +73,8 @@ def make_host_model_runner_descriptor(
'skill_authoring': skill_authoring,
},
permissions={
'models': ['list', 'invoke', 'stream'],
'tools': ['list', 'detail', 'call'],
'models': ['invoke', 'stream'],
'tools': ['detail', 'call'],
'knowledge_bases': ['list', 'retrieve'],
},
)

View File

@@ -1,169 +0,0 @@
"""Tests for DifyServiceAPIRunner pure utility methods.
Tests the helper methods that don't require real Dify API calls.
"""
from __future__ import annotations
import pytest
class TestDifyExtractTextOutput:
"""Tests for _extract_dify_text_output method."""
def _create_runner(self):
"""Create runner instance."""
from unittest.mock import MagicMock
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
mock_app = MagicMock()
pipeline_config = {
'ai': {
'dify-service-api': {
'app-type': 'chat',
'api-key': 'test-key',
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
runner.dify_client = MagicMock()
return runner
def test_extract_none_value(self):
"""None returns empty string."""
runner = self._create_runner()
result = runner._extract_dify_text_output(None)
assert result == ''
def test_extract_string_value(self):
"""Plain string is returned."""
runner = self._create_runner()
result = runner._extract_dify_text_output('plain text')
assert result == 'plain text'
def test_extract_dict_with_content(self):
"""Dict with 'content' key extracts content."""
runner = self._create_runner()
result = runner._extract_dify_text_output({'content': 'extracted content'})
assert result == 'extracted content'
def test_extract_dict_without_content(self):
"""Dict without 'content' key is JSON dumped."""
runner = self._create_runner()
result = runner._extract_dify_text_output({'key': 'value'})
assert 'key' in result
assert 'value' in result
def test_extract_json_string_with_content(self):
"""JSON string with 'content' key extracts content."""
runner = self._create_runner()
result = runner._extract_dify_text_output('{"content": "json content"}')
assert result == 'json content'
def test_extract_json_string_without_content(self):
"""JSON string without 'content' key returns original."""
runner = self._create_runner()
result = runner._extract_dify_text_output('{"other": "value"}')
assert '{"other": "value"}' in result
def test_extract_whitespace_string(self):
"""Whitespace string returns empty."""
runner = self._create_runner()
result = runner._extract_dify_text_output(' ')
assert result == ''
class TestDifyRunnerConfigValidation:
"""Tests for runner config validation."""
def test_invalid_app_type_raises(self):
"""Invalid app-type raises DifyAPIError."""
from unittest.mock import MagicMock
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
from langbot.libs.dify_service_api.v1.errors import DifyAPIError
mock_app = MagicMock()
pipeline_config = {
'ai': {
'dify-service-api': {
'app-type': 'invalid-type',
'api-key': 'test',
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
}
with pytest.raises(DifyAPIError, match='不支持'):
DifyServiceAPIRunner(mock_app, pipeline_config)
def test_valid_app_types(self):
"""Valid app-types don't raise."""
from unittest.mock import MagicMock
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
mock_app = MagicMock()
for app_type in ['chat', 'agent', 'workflow']:
pipeline_config = {
'ai': {
'dify-service-api': {
'app-type': app_type,
'api-key': 'test',
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
# Should not raise
assert runner is not None
class TestDifyRunnerInit:
"""Tests for runner initialization."""
def test_runner_stores_config(self):
"""Runner stores pipeline_config."""
from unittest.mock import MagicMock
from langbot.pkg.provider.runners.difysvapi import DifyServiceAPIRunner
mock_app = MagicMock()
pipeline_config = {
'ai': {
'dify-service-api': {
'app-type': 'chat',
'api-key': 'test-key',
'base-url': 'https://api.dify.ai',
}
},
'output': {'misc': {}}
}
runner = DifyServiceAPIRunner(mock_app, pipeline_config)
assert runner.pipeline_config == pipeline_config
assert runner.ap == mock_app

View File

@@ -1,242 +0,0 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import pytest
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
class RecordingProvider:
def __init__(self):
self.requests: list[dict] = []
async def invoke_llm(self, query, model, messages, funcs, extra_args=None, remove_think=None):
self.requests.append(
{
'messages': list(messages),
'funcs': list(funcs),
'remove_think': remove_think,
}
)
if len(self.requests) == 1:
return provider_message.Message(
role='assistant',
content='Let me calculate that exactly.',
tool_calls=[
provider_message.ToolCall(
id='call-1',
type='function',
function=provider_message.FunctionCall(
name='exec',
arguments=json.dumps(
{'command': ("python - <<'PY'\nnums = [1, 2, 3, 4]\nprint(sum(nums) / len(nums))\nPY")}
),
),
)
],
)
tool_result = json.loads(messages[-1].content)
return provider_message.Message(
role='assistant',
content=f'The average is {tool_result["stdout"]}.',
)
class RecordingStreamProvider:
def __init__(self):
self.stream_requests: list[dict] = []
def invoke_llm_stream(self, query, model, messages, funcs, extra_args=None, remove_think=None):
self.stream_requests.append(
{
'messages': list(messages),
'funcs': list(funcs),
'remove_think': remove_think,
}
)
async def _stream():
if len(self.stream_requests) == 1:
yield provider_message.MessageChunk(
role='assistant',
tool_calls=[
provider_message.ToolCall(
id='call-1',
type='function',
function=provider_message.FunctionCall(
name='exec',
arguments=json.dumps({'command': "python -c 'print(1)'"}),
),
)
],
is_final=True,
)
return
yield provider_message.MessageChunk(
role='assistant',
content='Tool execution failed.',
is_final=True,
)
return _stream()
def make_query() -> pipeline_query.Query:
adapter = AsyncMock()
adapter.is_stream_output_supported = AsyncMock(return_value=False)
return pipeline_query.Query.model_construct(
query_id='avg-query',
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
message_chain=[],
message_event=None,
adapter=adapter,
pipeline_uuid='pipeline-uuid',
bot_uuid='bot-uuid',
pipeline_config={
'ai': {
'runner': {'runner': 'local-agent'},
'local-agent': {'model': {'primary': 'test-model-uuid', 'fallbacks': []}, 'prompt': 'test-prompt'},
},
'output': {'misc': {'remove-think': False}},
},
prompt=SimpleNamespace(messages=[]),
messages=[],
user_message=provider_message.Message(
role='user',
content='Please calculate the average of 1, 2, 3, and 4.',
),
use_funcs=[SimpleNamespace(name='exec')],
use_llm_model_uuid='test-model-uuid',
variables={},
)
@pytest.mark.asyncio
async def test_localagent_uses_exec_for_exact_calculation():
provider = RecordingProvider()
model = SimpleNamespace(
provider=provider,
model_entity=SimpleNamespace(
uuid='test-model-uuid',
name='test-model',
abilities=['func_call'],
extra_args={},
),
)
tool_manager = SimpleNamespace(
execute_func_call=AsyncMock(
return_value={
'session_id': 'avg-query',
'backend': 'podman',
'status': 'completed',
'ok': True,
'exit_code': 0,
'stdout': '2.5',
'stderr': '',
'duration_ms': 18,
}
)
)
app = SimpleNamespace(
logger=Mock(),
model_mgr=SimpleNamespace(get_model_by_uuid=AsyncMock(return_value=model)),
tool_mgr=tool_manager,
rag_mgr=SimpleNamespace(),
box_service=SimpleNamespace(
get_system_guidance=Mock(
return_value=(
'When the exec tool is available, use it for exact calculations, statistics, '
'structured data parsing, and code execution instead of estimating mentally. '
'Unless the user explicitly asks for the script, code, or implementation details, '
'do not include the generated script in the final answer. '
'A default workspace is mounted at /workspace for file tasks.'
)
),
),
skill_mgr=SimpleNamespace(
get_skills_for_pipeline=AsyncMock(return_value=[]),
detect_skill_activation=AsyncMock(return_value=None),
build_activation_prompt=Mock(return_value=None),
),
)
runner = LocalAgentRunner(app, pipeline_config={})
query = make_query()
results = [message async for message in runner.run(query)]
assert [message.role for message in results] == ['assistant', 'tool', 'assistant']
assert results[-1].content == 'The average is 2.5.'
tool_manager.execute_func_call.assert_awaited_once()
tool_name, tool_parameters = tool_manager.execute_func_call.await_args.args[:2]
assert tool_name == 'exec'
assert 'print(sum(nums) / len(nums))' in tool_parameters['command']
first_request = provider.requests[0]
assert any(
message.role == 'system'
and 'exec' in str(message.content)
and 'exact calculations' in str(message.content)
and 'Unless the user explicitly asks for the script' in str(message.content)
and '/workspace' in str(message.content)
for message in first_request['messages']
)
assert [tool.name for tool in first_request['funcs']] == ['exec']
@pytest.mark.asyncio
async def test_localagent_streaming_tool_error_yields_message_chunks():
provider = RecordingStreamProvider()
model = SimpleNamespace(
provider=provider,
model_entity=SimpleNamespace(
uuid='test-model-uuid',
name='test-model',
abilities=['func_call'],
extra_args={},
),
)
adapter = AsyncMock()
adapter.is_stream_output_supported = AsyncMock(return_value=True)
query = make_query()
query.adapter = adapter
app = SimpleNamespace(
logger=Mock(),
model_mgr=SimpleNamespace(get_model_by_uuid=AsyncMock(return_value=model)),
tool_mgr=SimpleNamespace(execute_func_call=AsyncMock(side_effect=RuntimeError('boom'))),
rag_mgr=SimpleNamespace(),
box_service=SimpleNamespace(
get_system_guidance=Mock(return_value='sandbox guidance'),
),
skill_mgr=SimpleNamespace(
get_skills_for_pipeline=AsyncMock(return_value=[]),
detect_skill_activation=AsyncMock(return_value=None),
build_activation_prompt=Mock(return_value=None),
),
)
runner = LocalAgentRunner(app, pipeline_config={})
results = [message async for message in runner.run(query)]
assert all(isinstance(message, provider_message.MessageChunk) for message in results)
assert any(message.role == 'tool' and message.content == 'err: boom' for message in results)

View File

@@ -21,7 +21,6 @@ from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
from langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl import ModelScopeChatCompletions
from langbot.pkg.provider.modelmgr.token import TokenManager
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
DEFAULT_RUNNER_ID = 'plugin:langbot/local-agent/default'
@@ -43,8 +42,8 @@ class FakeAgentRunnerRegistry:
],
capabilities={'tool_calling': True, 'knowledge_retrieval': True, 'multimodal_input': True},
permissions={
'models': ['list', 'invoke', 'stream'],
'tools': ['list', 'detail', 'call'],
'models': ['invoke', 'stream'],
'tools': ['detail', 'call'],
'knowledge_bases': ['list', 'retrieve'],
},
)
@@ -320,8 +319,3 @@ async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline()
processed_query = result.new_query
assert processed_query.use_llm_model_uuid == model_uuid
runner = SimpleNamespace(ap=ap, pipeline_config=pipeline_config)
candidates = await LocalAgentRunner._get_model_candidates(runner, processed_query)
assert [model.model_entity.uuid for model in candidates] == [model_uuid]

View File

@@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, Mock
import pytest
from langbot_plugin.api.entities.builtin.agent_runner.manifest import (
AgentRunnerCapabilities,
AgentRunnerPermissions,
)
from langbot_plugin.api.entities.builtin.pipeline.query import Query
from langbot_plugin.api.entities.builtin.platform.entities import Friend
from langbot_plugin.api.entities.builtin.platform.events import FriendMessage
@@ -24,22 +28,23 @@ class _FakeRunnerDescriptor:
{'name': 'knowledge-bases', 'type': 'knowledge-base-multi-selector', 'default': []},
]
permissions = {
'models': ['list', 'invoke', 'stream'],
'tools': ['list', 'detail', 'call'],
'models': ['invoke', 'stream'],
'tools': ['detail', 'call'],
'knowledge_bases': ['list', 'retrieve'],
}
capabilities = {
'tool_calling': True,
'knowledge_retrieval': True,
'multimodal_input': True,
'skill_authoring': True,
}
permissions = AgentRunnerPermissions.model_validate(permissions)
capabilities = AgentRunnerCapabilities(
tool_calling=True,
knowledge_retrieval=True,
multimodal_input=True,
skill_authoring=True,
)
def supports_tool_calling(self):
return self.capabilities.get('tool_calling', False)
return self.capabilities.tool_calling
def supports_knowledge_retrieval(self):
return self.capabilities.get('knowledge_retrieval', False)
return self.capabilities.knowledge_retrieval
def _make_query() -> Query: