mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-22 21:44:20 +00:00
feat(agent-runner): add plugin runner host integration
This commit is contained in:
@@ -7,6 +7,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
from langbot_plugin.entities.io.actions.enums import PluginToRuntimeAction, RuntimeToLangBotAction
|
||||
|
||||
|
||||
@@ -27,6 +28,22 @@ def compiled_params(statement):
|
||||
return statement.compile().params
|
||||
|
||||
|
||||
def make_agent_resources(
|
||||
models: list[dict] | None = None,
|
||||
tools: list[dict] | None = None,
|
||||
knowledge_bases: list[dict] | None = None,
|
||||
):
|
||||
"""Create a minimal AgentRun resources payload for run-scoped action tests."""
|
||||
return {
|
||||
'models': models or [],
|
||||
'tools': tools or [],
|
||||
'knowledge_bases': knowledge_bases or [],
|
||||
'files': [],
|
||||
'storage': {'plugin_storage': False, 'workspace_storage': False},
|
||||
'platform_capabilities': {},
|
||||
}
|
||||
|
||||
|
||||
class TestRagRerankAction:
|
||||
"""Tests for RAG rerank action handler."""
|
||||
|
||||
@@ -421,3 +438,433 @@ class TestHandlerQueryLookup:
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data == {'bot_uuid': 'test-bot-uuid'}
|
||||
|
||||
|
||||
class TestAgentRunProxyActions:
|
||||
"""Tests for AgentRunner proxy actions that need host Query semantics."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
mock_app = Mock()
|
||||
mock_app.logger = Mock()
|
||||
mock_app.query_pool = Mock()
|
||||
mock_app.query_pool.cached_queries = {}
|
||||
mock_app.model_mgr = Mock()
|
||||
mock_app.model_mgr.get_model_by_uuid = AsyncMock()
|
||||
mock_app.model_mgr.get_rerank_model_by_uuid = AsyncMock()
|
||||
mock_app.tool_mgr = Mock()
|
||||
mock_app.tool_mgr.execute_func_call = AsyncMock(return_value={'ok': True})
|
||||
return mock_app
|
||||
|
||||
@staticmethod
|
||||
def query(remove_think=True):
|
||||
return SimpleNamespace(
|
||||
pipeline_config={'output': {'misc': {'remove-think': remove_think}}},
|
||||
variables={},
|
||||
prompt=SimpleNamespace(
|
||||
messages=[provider_message.Message(role='system', content='effective prompt')]
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_prompt_returns_query_effective_prompt(self, app):
|
||||
"""GET_PROMPT returns the preprocessed Query prompt for the active run."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
run_id = 'run_proxy_get_prompt'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[900] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=900,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(),
|
||||
available_apis={'prompt_get': True},
|
||||
)
|
||||
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.GET_PROMPT.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data['prompt'][0]['role'] == 'system'
|
||||
assert response.data['prompt'][0]['content'] == 'effective prompt'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_restores_query_and_model_options(self, app):
|
||||
"""INVOKE_LLM passes Query, model extra_args and remove-think to provider."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_options'
|
||||
query = self.query(remove_think=True)
|
||||
app.query_pool.cached_queries[901] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=901,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_001'}]),
|
||||
)
|
||||
|
||||
provider = SimpleNamespace(
|
||||
invoke_llm=AsyncMock(return_value=provider_message.Message(role='assistant', content='ok')),
|
||||
)
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(
|
||||
abilities=['func_call'],
|
||||
extra_args={'temperature': 0.2, 'top_p': 0.8},
|
||||
),
|
||||
provider=provider,
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
'funcs': [{
|
||||
'name': 'search',
|
||||
'human_desc': 'Search',
|
||||
'description': 'Search',
|
||||
'parameters': {'type': 'object'},
|
||||
}],
|
||||
'extra_args': {'temperature': 0.7, 'presence_penalty': 0.1},
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
provider.invoke_llm.assert_awaited_once()
|
||||
kwargs = provider.invoke_llm.await_args.kwargs
|
||||
assert kwargs['query'] is query
|
||||
assert kwargs['extra_args'] == {
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.8,
|
||||
'presence_penalty': 0.1,
|
||||
}
|
||||
assert kwargs['remove_think'] is True
|
||||
assert [tool.name for tool in kwargs['funcs']] == ['search']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_returns_provider_usage(self, app):
|
||||
"""INVOKE_LLM includes optional provider usage in the action response."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
from langbot.pkg.provider.modelmgr import requester as model_requester
|
||||
|
||||
usage = {
|
||||
'prompt_tokens': 11,
|
||||
'completion_tokens': 7,
|
||||
'total_tokens': 18,
|
||||
'prompt_tokens_details': {'cached_tokens': 3},
|
||||
}
|
||||
|
||||
class UsageProvider:
|
||||
async def invoke_llm(self, **kwargs):
|
||||
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
|
||||
return provider_message.Message(role='assistant', content='ok')
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_usage'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[905] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=905,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_usage_001'}]),
|
||||
)
|
||||
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={}),
|
||||
provider=UsageProvider(),
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_usage_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data['message']['content'] == 'ok'
|
||||
assert response.data['usage'] == usage
|
||||
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_restores_query_and_options(self, app):
|
||||
"""INVOKE_LLM_STREAM applies the same host context as non-streaming calls."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
class StreamProvider:
|
||||
def __init__(self):
|
||||
self.kwargs = None
|
||||
|
||||
async def invoke_llm_stream(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
yield provider_message.MessageChunk(role='assistant', content='hi')
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_stream_options'
|
||||
query = self.query(remove_think=False)
|
||||
app.query_pool.cached_queries[902] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=902,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_stream_001'}]),
|
||||
)
|
||||
|
||||
provider = StreamProvider()
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={'max_tokens': 128}),
|
||||
provider=provider,
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
responses = []
|
||||
try:
|
||||
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_stream_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
'funcs': [{
|
||||
'name': 'search',
|
||||
'human_desc': 'Search',
|
||||
'description': 'Search',
|
||||
'parameters': {'type': 'object'},
|
||||
}],
|
||||
'extra_args': {'max_tokens': 256},
|
||||
'remove_think': True,
|
||||
})
|
||||
async for response in stream:
|
||||
responses.append(response)
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert [response.code for response in responses] == [0]
|
||||
assert provider.kwargs['query'] is query
|
||||
assert provider.kwargs['extra_args'] == {'max_tokens': 256}
|
||||
assert provider.kwargs['remove_think'] is True
|
||||
assert provider.kwargs['funcs'] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_skips_none_chunks(self, app):
|
||||
"""INVOKE_LLM_STREAM tolerates provider heartbeat/no-op chunks."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
class StreamProvider:
|
||||
async def invoke_llm_stream(self, **kwargs):
|
||||
yield provider_message.MessageChunk(role='assistant', content='ok')
|
||||
yield None
|
||||
yield provider_message.MessageChunk(role='assistant', content=' done', is_final=True)
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_stream_none_chunks'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[904] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=904,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_stream_002'}]),
|
||||
)
|
||||
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={}),
|
||||
provider=StreamProvider(),
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
responses = []
|
||||
try:
|
||||
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_stream_002',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
})
|
||||
async for response in stream:
|
||||
responses.append(response)
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert [response.code for response in responses] == [0, 0]
|
||||
assert [response.data['chunk']['content'] for response in responses] == ['ok', ' done']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_llm_stream_returns_provider_usage_event(self, app):
|
||||
"""INVOKE_LLM_STREAM emits a final usage-only action response when available."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
from langbot.pkg.provider.modelmgr import requester as model_requester
|
||||
|
||||
usage = {
|
||||
'prompt_tokens': 9,
|
||||
'completion_tokens': 4,
|
||||
'total_tokens': 13,
|
||||
'prompt_tokens_details': {'cached_tokens': 2},
|
||||
}
|
||||
|
||||
class StreamProvider:
|
||||
async def invoke_llm_stream(self, **kwargs):
|
||||
yield provider_message.MessageChunk(role='assistant', content='ok')
|
||||
kwargs['query'].variables[model_requester.LLM_USAGE_QUERY_VARIABLE] = usage
|
||||
|
||||
run_id = 'run_proxy_invoke_llm_stream_usage'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[906] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=906,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'llm_stream_usage_001'}]),
|
||||
)
|
||||
|
||||
model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(abilities=[], extra_args={}),
|
||||
provider=StreamProvider(),
|
||||
)
|
||||
app.model_mgr.get_model_by_uuid.return_value = model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
responses = []
|
||||
try:
|
||||
stream = runtime_handler.actions[PluginToRuntimeAction.INVOKE_LLM_STREAM.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'llm_model_uuid': 'llm_stream_usage_001',
|
||||
'messages': [{'role': 'user', 'content': 'hello'}],
|
||||
})
|
||||
async for response in stream:
|
||||
responses.append(response)
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert [response.code for response in responses] == [0, 0]
|
||||
assert responses[0].data['chunk']['content'] == 'ok'
|
||||
assert responses[1].data == {'usage': usage}
|
||||
assert model_requester.LLM_USAGE_QUERY_VARIABLE not in query.variables
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_passes_current_query(self, app):
|
||||
"""CALL_TOOL passes the current Query back into tool execution."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
run_id = 'run_proxy_call_tool_query'
|
||||
query = self.query()
|
||||
app.query_pool.cached_queries[903] = query
|
||||
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=903,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(tools=[{'tool_name': 'test/search'}]),
|
||||
)
|
||||
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.CALL_TOOL.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'tool_name': 'test/search',
|
||||
'parameters': {'q': 'langbot'},
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert getattr(query, '_agent_run_session')['run_id'] == run_id
|
||||
app.tool_mgr.execute_func_call.assert_awaited_once_with(
|
||||
name='test/search',
|
||||
parameters={'q': 'langbot'},
|
||||
query=query,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_rerank_uses_authorized_model_and_extra_args(self, app):
|
||||
"""INVOKE_RERANK validates run-scoped model access and merges model extra_args."""
|
||||
from langbot.pkg.agent.runner.session_registry import get_session_registry
|
||||
|
||||
run_id = 'run_proxy_rerank_options'
|
||||
registry = get_session_registry()
|
||||
await registry.unregister(run_id)
|
||||
await registry.register(
|
||||
run_id=run_id,
|
||||
runner_id='plugin:test/runner/default',
|
||||
query_id=904,
|
||||
plugin_identity='test/runner',
|
||||
resources=make_agent_resources(models=[{'model_id': 'rerank_001'}]),
|
||||
)
|
||||
|
||||
provider = SimpleNamespace(
|
||||
invoke_rerank=AsyncMock(return_value=[
|
||||
{'index': 0, 'relevance_score': 0.2},
|
||||
{'index': 1, 'relevance_score': 0.9},
|
||||
]),
|
||||
)
|
||||
rerank_model = SimpleNamespace(
|
||||
model_entity=SimpleNamespace(extra_args={'top_n': 5, 'return_documents': False}),
|
||||
provider=provider,
|
||||
)
|
||||
app.model_mgr.get_rerank_model_by_uuid.return_value = rerank_model
|
||||
runtime_handler = make_handler(app)
|
||||
|
||||
try:
|
||||
response = await runtime_handler.actions[PluginToRuntimeAction.INVOKE_RERANK.value]({
|
||||
'run_id': run_id,
|
||||
'caller_plugin_identity': 'test/runner',
|
||||
'rerank_model_uuid': 'rerank_001',
|
||||
'query': 'hello',
|
||||
'documents': ['a', 'b'],
|
||||
'top_k': 1,
|
||||
'extra_args': {'top_n': 2},
|
||||
})
|
||||
finally:
|
||||
await registry.unregister(run_id)
|
||||
|
||||
assert response.code == 0
|
||||
assert response.data['results'] == [{'index': 1, 'relevance_score': 0.9}]
|
||||
provider.invoke_rerank.assert_awaited_once()
|
||||
kwargs = provider.invoke_rerank.await_args.kwargs
|
||||
assert kwargs['extra_args'] == {'top_n': 2, 'return_documents': False}
|
||||
|
||||
Reference in New Issue
Block a user