Compare commits

..

2 Commits

Author SHA1 Message Date
huanghuoguoguo
516371a6c1 style(platform): format web page bot adapter 2026-06-14 10:53:47 +08:00
huanghuoguoguo
4f8e2496d2 fix(platform): delegate web page bot stream helpers 2026-06-14 10:45:23 +08:00
12 changed files with 38 additions and 210 deletions

View File

@@ -197,11 +197,10 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
self,
file_bytes: bytes,
task_context: taskmgr.TaskContext | None,
) -> tuple[str | None, str | None, str | None]:
) -> tuple[str | None, str | None]:
"""Extract plugin identity and dependency metadata from a plugin package."""
plugin_author = None
plugin_name = None
plugin_version = None
try:
with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
@@ -210,7 +209,6 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
metadata = manifest.get('metadata', {})
plugin_author = metadata.get('author')
plugin_name = metadata.get('name')
plugin_version = metadata.get('version')
except Exception:
pass
@@ -229,7 +227,7 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
except Exception:
pass
return plugin_author, plugin_name, plugin_version
return plugin_author, plugin_name
async def _install_mcp_from_marketplace(
self,
@@ -371,7 +369,6 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
):
plugin_author = install_info.get('plugin_author')
plugin_name = install_info.get('plugin_name')
plugin_file_transferred = False
if install_source == PluginInstallSource.MARKETPLACE:
# Handle marketplace plugin/mcp/skill installation
@@ -466,18 +463,9 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
)
file_bytes = download_resp.content
plugin_author, plugin_name, plugin_version = self._inspect_plugin_package(
file_bytes,
task_context,
)
if task_context is not None and plugin_author and plugin_name:
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
if task_context is not None and plugin_version:
task_context.metadata['plugin_version'] = plugin_version
self._inspect_plugin_package(file_bytes, task_context)
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
install_source = PluginInstallSource.LOCAL
plugin_file_transferred = True
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')
# Continue to install via runtime
else:
@@ -493,14 +481,12 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
mcp_resp.raise_for_status()
raise Exception(f'Failed to get MCP {plugin_author}/{plugin_name}')
if install_source == PluginInstallSource.LOCAL and not plugin_file_transferred:
if install_source == PluginInstallSource.LOCAL:
# transfer file before install
file_bytes = install_info['plugin_file']
plugin_author, plugin_name, plugin_version = self._inspect_plugin_package(file_bytes, task_context)
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
if task_context is not None and plugin_author and plugin_name:
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
if task_context is not None and plugin_version:
task_context.metadata['plugin_version'] = plugin_version
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
del install_info['plugin_file']
@@ -537,11 +523,9 @@ class PluginRuntimeConnector(ManagedRuntimeConnector):
task_context.metadata['download_speed'] = downloaded / elapsed if elapsed > 0 else 0
file_bytes = b''.join(chunks)
plugin_author, plugin_name, plugin_version = self._inspect_plugin_package(file_bytes, task_context)
plugin_author, plugin_name = self._inspect_plugin_package(file_bytes, task_context)
if task_context is not None and plugin_author and plugin_name:
task_context.metadata['plugin_name'] = f'{plugin_author}/{plugin_name}'
if task_context is not None and plugin_version:
task_context.metadata['plugin_version'] = plugin_version
file_key = await self.handler.send_file(file_bytes, 'lbpkg')
install_info['plugin_file_key'] = file_key
self.ap.logger.info(f'Transfered file {file_key} to plugin runtime')

View File

@@ -12,19 +12,6 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.provider.message as provider_message
LLM_USAGE_QUERY_VARIABLE = '_llm_usage'
STREAM_USAGE_QUERY_VARIABLE = '_stream_usage'
def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None:
"""Store the latest provider usage on the query for upstream action handlers."""
if query is None or not usage_info:
return
if query.variables is None:
query.variables = {}
query.variables[LLM_USAGE_QUERY_VARIABLE] = dict(usage_info)
class RuntimeProvider:
"""运行时模型提供商"""
@@ -80,7 +67,6 @@ class RuntimeProvider:
if isinstance(result, tuple):
msg, usage_info = result
if usage_info:
_store_llm_usage(query, usage_info)
input_tokens = usage_info.get('prompt_tokens', 0)
output_tokens = usage_info.get('completion_tokens', 0)
return msg
@@ -160,12 +146,11 @@ class RuntimeProvider:
if query:
if query.variables is None:
query.variables = {}
if STREAM_USAGE_QUERY_VARIABLE in query.variables:
usage_info = query.variables[STREAM_USAGE_QUERY_VARIABLE]
_store_llm_usage(query, usage_info)
if '_stream_usage' in query.variables:
usage_info = query.variables['_stream_usage']
input_tokens = usage_info.get('prompt_tokens', 0)
output_tokens = usage_info.get('completion_tokens', 0)
del query.variables[STREAM_USAGE_QUERY_VARIABLE]
del query.variables['_stream_usage']
except Exception as e:
status = 'error'
error_message = str(e)

View File

@@ -262,82 +262,32 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
- dict with the same keys
- missing ``total_tokens`` (derived from prompt + completion)
- ``None`` / partially-populated usage (defaults to 0)
- provider-specific token details, including cache token counters
"""
if usage is None:
return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}
def _plain_value(value: typing.Any) -> typing.Any:
if value is None:
return None
if isinstance(value, dict):
return {k: _plain_value(v) for k, v in value.items() if v is not None}
if isinstance(value, (list, tuple)):
return [_plain_value(v) for v in value]
def _get(key: str) -> typing.Any:
if isinstance(usage, dict):
return usage.get(key)
return getattr(usage, key, None)
model_dump = getattr(value, 'model_dump', None)
if callable(model_dump):
try:
dumped = model_dump()
if isinstance(dumped, dict):
return _plain_value(dumped)
except Exception:
pass
return value
def _usage_dict(value: typing.Any) -> dict[str, typing.Any]:
if value is None:
return {}
plain = _plain_value(value)
if isinstance(plain, dict):
return plain
def _is_mock_attr(attr: typing.Any) -> bool:
return type(attr).__module__.startswith('unittest.mock')
data: dict[str, typing.Any] = {}
for key in (
'prompt_tokens',
'completion_tokens',
'total_tokens',
'prompt_tokens_details',
'completion_tokens_details',
'cache_creation_input_tokens',
'cache_read_input_tokens',
'input_token_details',
'output_token_details',
):
attr_value = getattr(value, key, None)
if attr_value is not None and not _is_mock_attr(attr_value):
data[key] = _plain_value(attr_value)
return data
def _to_int(value: typing.Any) -> int:
try:
return int(value or 0)
except (TypeError, ValueError):
return 0
normalized = _usage_dict(usage)
prompt_tokens = _to_int(normalized.get('prompt_tokens'))
completion_tokens = _to_int(normalized.get('completion_tokens'))
total_tokens = _to_int(normalized.get('total_tokens'))
prompt_tokens = _get('prompt_tokens') or 0
completion_tokens = _get('completion_tokens') or 0
total_tokens = _get('total_tokens') or 0
# Some providers omit total_tokens in streaming usage; derive it.
if not total_tokens:
total_tokens = prompt_tokens + completion_tokens
normalized['prompt_tokens'] = prompt_tokens
normalized['completion_tokens'] = completion_tokens
normalized['total_tokens'] = total_tokens
return normalized
return {
'prompt_tokens': int(prompt_tokens),
'completion_tokens': int(completion_tokens),
'total_tokens': int(total_tokens),
}
def _extract_usage(self, response) -> dict | None:
def _extract_usage(self, response) -> dict:
"""Extract usage info from a non-streaming LiteLLM response."""
usage = getattr(response, 'usage', None)
if usage is None:
return None
return self._normalize_usage(usage)
return self._normalize_usage(getattr(response, 'usage', None))
@staticmethod
def _as_dict(value: typing.Any) -> dict:
@@ -536,7 +486,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
if query is not None:
if query.variables is None:
query.variables = {}
query.variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage_info
query.variables['_stream_usage'] = usage_info
if not hasattr(chunk, 'choices') or not chunk.choices:
continue

View File

@@ -1,6 +0,0 @@
class ToolNotFoundError(ValueError):
"""Raised when a requested tool cannot be found in any active loader."""
def __init__(self, name: str):
self.name = name
super().__init__(f'Tool not found: {name}')

View File

@@ -4,15 +4,12 @@ import abc
import typing
from typing import TYPE_CHECKING
from langbot_plugin.api.definition.components.manifest import ComponentManifest
from langbot_plugin.api.entities.events import pipeline_query
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
if TYPE_CHECKING:
from ...core import app
ToolLookupResult = resource_tool.LLMTool | ComponentManifest
preregistered_loaders: list[typing.Type[ToolLoader]] = []
@@ -46,13 +43,6 @@ class ToolLoader(abc.ABC):
"""获取所有工具"""
pass
async def get_tool(self, name: str) -> ToolLookupResult | None:
"""Get one tool by name."""
for tool in await self.get_tools():
if tool.name == name:
return tool
return None
@abc.abstractmethod
async def has_tool(self, name: str) -> bool:
"""检查工具是否存在"""

View File

@@ -567,13 +567,6 @@ class MCPLoader(loader.ToolLoader):
return True
return False
async def get_tool(self, name: str) -> resource_tool.LLMTool | None:
for session in self.sessions.values():
for function in session.get_tools():
if function.name == name:
return function
return None
async def invoke_tool(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
"""执行工具调用"""
for session in self.sessions.values():

View File

@@ -7,7 +7,6 @@ import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from langbot_plugin.api.entities.events import pipeline_query
from .. import loader
from ..errors import ToolNotFoundError
from . import skill as skill_loader
EXEC_TOOL_NAME = 'exec'
@@ -91,7 +90,7 @@ class NativeToolLoader(loader.ToolLoader):
return await self._invoke_glob(parameters, query)
if name == GREP_TOOL_NAME:
return await self._invoke_grep(parameters, query)
raise ToolNotFoundError(name)
raise ValueError(f'未找到工具: {name}')
async def shutdown(self):
pass

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import typing
import traceback
from langbot_plugin.api.definition.components.manifest import ComponentManifest
from langbot_plugin.api.entities.events import pipeline_query
from .. import loader
@@ -40,7 +39,7 @@ class PluginToolLoader(loader.ToolLoader):
return True
return False
async def get_tool(self, name: str) -> ComponentManifest | None:
async def _get_tool(self, name: str) -> resource_tool.LLMTool:
for tool in await self.ap.plugin_connector.list_tools():
if tool.metadata.name == name:
return tool

View File

@@ -6,9 +6,6 @@ from typing import TYPE_CHECKING
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from langbot_plugin.api.entities.events import pipeline_query
from . import loader as tool_loader
from .errors import ToolNotFoundError
if TYPE_CHECKING:
from ...core import app
from langbot.pkg.provider.tools.loaders import (
@@ -70,20 +67,6 @@ class ToolManager:
return all_functions
async def get_tool_by_name(self, name: str) -> tool_loader.ToolLookupResult | None:
"""Get tool by name from any active loader."""
for active_loader in (
self.native_tool_loader,
self.plugin_tool_loader,
self.mcp_tool_loader,
self.skill_tool_loader,
):
tool = await active_loader.get_tool(name)
if tool:
return tool
return None
async def generate_tools_for_openai(self, use_funcs: list[resource_tool.LLMTool]) -> list:
tools = []
@@ -115,7 +98,7 @@ class ToolManager:
if await self.skill_tool_loader.has_tool(name):
telemetry_features.increment(query, 'tool_calls', 'skill')
return await self.skill_tool_loader.invoke_tool(name, parameters, query)
raise ToolNotFoundError(name)
raise ValueError(f'未找到工具: {name}')
async def shutdown(self):
await self.native_tool_loader.shutdown()

View File

@@ -49,30 +49,6 @@ class TestExtractDepsMetadata:
assert 'flask' in task_context.metadata['deps_list']
assert 'numpy' in task_context.metadata['deps_list']
def test_extract_plugin_identity_includes_version(self):
"""Extract plugin identity and version from manifest.yaml."""
connector = self._create_connector()
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w') as zf:
zf.writestr(
'manifest.yaml',
'\n'.join(
[
'metadata:',
' author: langbot-team',
' name: LangRAG',
' version: 0.1.8',
]
),
)
assert connector._inspect_plugin_package(zip_buffer.getvalue(), None) == (
'langbot-team',
'LangRAG',
'0.1.8',
)
def test_extract_deps_empty_requirements(self):
"""Handle empty requirements.txt."""
connector = self._create_connector()

View File

@@ -115,15 +115,6 @@ class TestExtractUsage:
assert result['prompt_tokens'] == 0
assert result['completion_tokens'] == 0
def test_extract_usage_without_provider_usage(self):
"""Missing provider usage is not treated as authoritative zero usage."""
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
response = Mock()
response.usage = None
assert requester._extract_usage(response) is None
class TestNormalizeUsage:
"""Test _normalize_usage helper covering real-world usage shapes"""
@@ -140,22 +131,6 @@ class TestNormalizeUsage:
)
assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20}
def test_preserves_token_details(self):
"""Provider token details such as cache counters are preserved."""
result = litellmchat.LiteLLMRequester._normalize_usage(
{
'prompt_tokens': 12,
'completion_tokens': 8,
'total_tokens': 20,
'prompt_tokens_details': {'cached_tokens': 7},
'completion_tokens_details': {'reasoning_tokens': 3},
}
)
assert result['prompt_tokens'] == 12
assert result['prompt_tokens_details'] == {'cached_tokens': 7}
assert result['completion_tokens_details'] == {'reasoning_tokens': 3}
def test_missing_total_is_derived(self):
"""When total_tokens is absent/zero it is derived from prompt + completion"""
usage = Mock()
@@ -191,7 +166,9 @@ class TestInvokeLLMStreamUsage:
if has_choice:
choice = Mock()
delta = Mock()
delta.model_dump = Mock(return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls})
delta.model_dump = Mock(
return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls}
)
choice.delta = delta
choice.finish_reason = finish_reason
chunk.choices = [choice]
@@ -336,8 +313,7 @@ class TestInvokeLLMStreamUsage:
with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())):
collected = [
chunk
async for chunk in requester.invoke_llm_stream(
chunk async for chunk in requester.invoke_llm_stream(
query=query,
model=model,
messages=messages,
@@ -812,9 +788,7 @@ class TestInvokeRerank:
with patch('httpx.AsyncClient', return_value=mock_client):
# arerank must NOT be called on the openai-compatible path
with patch.object(
litellmchat,
'arerank',
new_callable=AsyncMock,
litellmchat, 'arerank', new_callable=AsyncMock,
side_effect=AssertionError('arerank must not be used for openai-compatible provider'),
):
results = await requester.invoke_rerank(
@@ -1094,7 +1068,8 @@ class TestScanModels:
with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling:
mock_supports_function_calling.side_effect = (
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6' and custom_llm_provider is None
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6'
and custom_llm_provider is None
)
assert requester._supports_function_calling('kimi-k2.6') is True

View File

@@ -226,7 +226,7 @@ class TestToolManagerExecuteFuncCall:
@pytest.mark.asyncio
async def test_execute_raises_when_tool_not_found(self, mock_app_with_loaders, sample_query):
"""Test that execute_func_call raises ToolNotFoundError when tool not found."""
"""Test that execute_func_call raises ValueError when tool not found."""
toolmgr = get_toolmgr_module()
mock_app, mock_plugin_loader, mock_mcp_loader = mock_app_with_loaders
@@ -236,7 +236,7 @@ class TestToolManagerExecuteFuncCall:
manager = toolmgr.ToolManager(mock_app)
self._wire_loaders(manager, mock_app, mock_plugin_loader, mock_mcp_loader)
with pytest.raises(toolmgr.ToolNotFoundError, match='Tool not found: unknown_tool'):
with pytest.raises(ValueError, match='未找到工具'):
await manager.execute_func_call('unknown_tool', {}, sample_query)
@pytest.mark.asyncio