mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-14 17:56:03 +00:00
Compare commits
4 Commits
codex/web-
...
codex/spac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a04bfc364 | ||
|
|
e9fe2f2d43 | ||
|
|
27be09ab15 | ||
|
|
1ef4507d9a |
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sqlalchemy
|
||||
import traceback
|
||||
|
||||
@@ -84,8 +85,19 @@ class ModelManager:
|
||||
self.ap.logger.info('LangBot Space Models service is disabled, skipping sync.')
|
||||
return
|
||||
|
||||
sync_timeout = space_config.get('models_sync_timeout')
|
||||
try:
|
||||
await self.sync_new_models_from_space()
|
||||
if sync_timeout:
|
||||
await asyncio.wait_for(
|
||||
self.sync_new_models_from_space(),
|
||||
timeout=float(sync_timeout),
|
||||
)
|
||||
else:
|
||||
await self.sync_new_models_from_space()
|
||||
except asyncio.TimeoutError:
|
||||
self.ap.logger.warning(
|
||||
f'LangBot Space model sync timed out after {sync_timeout}s, skipping startup sync.'
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.warning('Failed to sync new models from LangBot Space, model list may not be updated.')
|
||||
self.ap.logger.warning(f' - Error: {e}')
|
||||
|
||||
@@ -12,6 +12,19 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
LLM_USAGE_QUERY_VARIABLE = '_llm_usage'
|
||||
STREAM_USAGE_QUERY_VARIABLE = '_stream_usage'
|
||||
|
||||
|
||||
def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None:
|
||||
"""Store the latest provider usage on the query for upstream action handlers."""
|
||||
if query is None or not usage_info:
|
||||
return
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
query.variables[LLM_USAGE_QUERY_VARIABLE] = dict(usage_info)
|
||||
|
||||
|
||||
class RuntimeProvider:
|
||||
"""运行时模型提供商"""
|
||||
|
||||
@@ -67,6 +80,7 @@ class RuntimeProvider:
|
||||
if isinstance(result, tuple):
|
||||
msg, usage_info = result
|
||||
if usage_info:
|
||||
_store_llm_usage(query, usage_info)
|
||||
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||
output_tokens = usage_info.get('completion_tokens', 0)
|
||||
return msg
|
||||
@@ -146,11 +160,12 @@ class RuntimeProvider:
|
||||
if query:
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
if '_stream_usage' in query.variables:
|
||||
usage_info = query.variables['_stream_usage']
|
||||
if STREAM_USAGE_QUERY_VARIABLE in query.variables:
|
||||
usage_info = query.variables[STREAM_USAGE_QUERY_VARIABLE]
|
||||
_store_llm_usage(query, usage_info)
|
||||
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||
output_tokens = usage_info.get('completion_tokens', 0)
|
||||
del query.variables['_stream_usage']
|
||||
del query.variables[STREAM_USAGE_QUERY_VARIABLE]
|
||||
except Exception as e:
|
||||
status = 'error'
|
||||
error_message = str(e)
|
||||
|
||||
@@ -262,32 +262,82 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
- dict with the same keys
|
||||
- missing ``total_tokens`` (derived from prompt + completion)
|
||||
- ``None`` / partially-populated usage (defaults to 0)
|
||||
- provider-specific token details, including cache token counters
|
||||
"""
|
||||
if usage is None:
|
||||
return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}
|
||||
|
||||
def _get(key: str) -> typing.Any:
|
||||
if isinstance(usage, dict):
|
||||
return usage.get(key)
|
||||
return getattr(usage, key, None)
|
||||
def _plain_value(value: typing.Any) -> typing.Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, dict):
|
||||
return {k: _plain_value(v) for k, v in value.items() if v is not None}
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_plain_value(v) for v in value]
|
||||
|
||||
prompt_tokens = _get('prompt_tokens') or 0
|
||||
completion_tokens = _get('completion_tokens') or 0
|
||||
total_tokens = _get('total_tokens') or 0
|
||||
model_dump = getattr(value, 'model_dump', None)
|
||||
if callable(model_dump):
|
||||
try:
|
||||
dumped = model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return _plain_value(dumped)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
def _usage_dict(value: typing.Any) -> dict[str, typing.Any]:
|
||||
if value is None:
|
||||
return {}
|
||||
plain = _plain_value(value)
|
||||
if isinstance(plain, dict):
|
||||
return plain
|
||||
|
||||
def _is_mock_attr(attr: typing.Any) -> bool:
|
||||
return type(attr).__module__.startswith('unittest.mock')
|
||||
|
||||
data: dict[str, typing.Any] = {}
|
||||
for key in (
|
||||
'prompt_tokens',
|
||||
'completion_tokens',
|
||||
'total_tokens',
|
||||
'prompt_tokens_details',
|
||||
'completion_tokens_details',
|
||||
'cache_creation_input_tokens',
|
||||
'cache_read_input_tokens',
|
||||
'input_token_details',
|
||||
'output_token_details',
|
||||
):
|
||||
attr_value = getattr(value, key, None)
|
||||
if attr_value is not None and not _is_mock_attr(attr_value):
|
||||
data[key] = _plain_value(attr_value)
|
||||
return data
|
||||
|
||||
def _to_int(value: typing.Any) -> int:
|
||||
try:
|
||||
return int(value or 0)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
normalized = _usage_dict(usage)
|
||||
|
||||
prompt_tokens = _to_int(normalized.get('prompt_tokens'))
|
||||
completion_tokens = _to_int(normalized.get('completion_tokens'))
|
||||
total_tokens = _to_int(normalized.get('total_tokens'))
|
||||
|
||||
# Some providers omit total_tokens in streaming usage; derive it.
|
||||
if not total_tokens:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
return {
|
||||
'prompt_tokens': int(prompt_tokens),
|
||||
'completion_tokens': int(completion_tokens),
|
||||
'total_tokens': int(total_tokens),
|
||||
}
|
||||
normalized['prompt_tokens'] = prompt_tokens
|
||||
normalized['completion_tokens'] = completion_tokens
|
||||
normalized['total_tokens'] = total_tokens
|
||||
return normalized
|
||||
|
||||
def _extract_usage(self, response) -> dict:
|
||||
def _extract_usage(self, response) -> dict | None:
|
||||
"""Extract usage info from a non-streaming LiteLLM response."""
|
||||
return self._normalize_usage(getattr(response, 'usage', None))
|
||||
usage = getattr(response, 'usage', None)
|
||||
if usage is None:
|
||||
return None
|
||||
return self._normalize_usage(usage)
|
||||
|
||||
@staticmethod
|
||||
def _as_dict(value: typing.Any) -> dict:
|
||||
@@ -486,7 +536,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
if query is not None:
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
query.variables['_stream_usage'] = usage_info
|
||||
query.variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage_info
|
||||
|
||||
if not hasattr(chunk, 'choices') or not chunk.choices:
|
||||
continue
|
||||
|
||||
6
src/langbot/pkg/provider/tools/errors.py
Normal file
6
src/langbot/pkg/provider/tools/errors.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class ToolNotFoundError(ValueError):
|
||||
"""Raised when a requested tool cannot be found in any active loader."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
super().__init__(f'Tool not found: {name}')
|
||||
@@ -4,12 +4,15 @@ import abc
|
||||
import typing
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langbot_plugin.api.definition.components.manifest import ComponentManifest
|
||||
from langbot_plugin.api.entities.events import pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...core import app
|
||||
|
||||
ToolLookupResult = resource_tool.LLMTool | ComponentManifest
|
||||
|
||||
|
||||
preregistered_loaders: list[typing.Type[ToolLoader]] = []
|
||||
|
||||
@@ -43,6 +46,13 @@ class ToolLoader(abc.ABC):
|
||||
"""获取所有工具"""
|
||||
pass
|
||||
|
||||
async def get_tool(self, name: str) -> ToolLookupResult | None:
|
||||
"""Get one tool by name."""
|
||||
for tool in await self.get_tools():
|
||||
if tool.name == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def has_tool(self, name: str) -> bool:
|
||||
"""检查工具是否存在"""
|
||||
|
||||
@@ -567,6 +567,13 @@ class MCPLoader(loader.ToolLoader):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_tool(self, name: str) -> resource_tool.LLMTool | None:
|
||||
for session in self.sessions.values():
|
||||
for function in session.get_tools():
|
||||
if function.name == name:
|
||||
return function
|
||||
return None
|
||||
|
||||
async def invoke_tool(self, name: str, parameters: dict, query: pipeline_query.Query) -> typing.Any:
|
||||
"""执行工具调用"""
|
||||
for session in self.sessions.values():
|
||||
|
||||
@@ -7,6 +7,7 @@ import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
from langbot_plugin.api.entities.events import pipeline_query
|
||||
|
||||
from .. import loader
|
||||
from ..errors import ToolNotFoundError
|
||||
from . import skill as skill_loader
|
||||
|
||||
EXEC_TOOL_NAME = 'exec'
|
||||
@@ -90,7 +91,7 @@ class NativeToolLoader(loader.ToolLoader):
|
||||
return await self._invoke_glob(parameters, query)
|
||||
if name == GREP_TOOL_NAME:
|
||||
return await self._invoke_grep(parameters, query)
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
raise ToolNotFoundError(name)
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from langbot_plugin.api.definition.components.manifest import ComponentManifest
|
||||
from langbot_plugin.api.entities.events import pipeline_query
|
||||
|
||||
from .. import loader
|
||||
@@ -39,7 +40,7 @@ class PluginToolLoader(loader.ToolLoader):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _get_tool(self, name: str) -> resource_tool.LLMTool:
|
||||
async def get_tool(self, name: str) -> ComponentManifest | None:
|
||||
for tool in await self.ap.plugin_connector.list_tools():
|
||||
if tool.metadata.name == name:
|
||||
return tool
|
||||
|
||||
@@ -6,6 +6,9 @@ from typing import TYPE_CHECKING
|
||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||
from langbot_plugin.api.entities.events import pipeline_query
|
||||
|
||||
from . import loader as tool_loader
|
||||
from .errors import ToolNotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...core import app
|
||||
from langbot.pkg.provider.tools.loaders import (
|
||||
@@ -67,6 +70,20 @@ class ToolManager:
|
||||
|
||||
return all_functions
|
||||
|
||||
async def get_tool_by_name(self, name: str) -> tool_loader.ToolLookupResult | None:
|
||||
"""Get tool by name from any active loader."""
|
||||
for active_loader in (
|
||||
self.native_tool_loader,
|
||||
self.plugin_tool_loader,
|
||||
self.mcp_tool_loader,
|
||||
self.skill_tool_loader,
|
||||
):
|
||||
tool = await active_loader.get_tool(name)
|
||||
if tool:
|
||||
return tool
|
||||
|
||||
return None
|
||||
|
||||
async def generate_tools_for_openai(self, use_funcs: list[resource_tool.LLMTool]) -> list:
|
||||
tools = []
|
||||
|
||||
@@ -98,7 +115,7 @@ class ToolManager:
|
||||
if await self.skill_tool_loader.has_tool(name):
|
||||
telemetry_features.increment(query, 'tool_calls', 'skill')
|
||||
return await self.skill_tool_loader.invoke_tool(name, parameters, query)
|
||||
raise ValueError(f'未找到工具: {name}')
|
||||
raise ToolNotFoundError(name)
|
||||
|
||||
async def shutdown(self):
|
||||
await self.native_tool_loader.shutdown()
|
||||
|
||||
@@ -115,6 +115,15 @@ class TestExtractUsage:
|
||||
assert result['prompt_tokens'] == 0
|
||||
assert result['completion_tokens'] == 0
|
||||
|
||||
def test_extract_usage_without_provider_usage(self):
|
||||
"""Missing provider usage is not treated as authoritative zero usage."""
|
||||
requester = litellmchat.LiteLLMRequester(ap=Mock(), config={})
|
||||
|
||||
response = Mock()
|
||||
response.usage = None
|
||||
|
||||
assert requester._extract_usage(response) is None
|
||||
|
||||
|
||||
class TestNormalizeUsage:
|
||||
"""Test _normalize_usage helper covering real-world usage shapes"""
|
||||
@@ -131,6 +140,22 @@ class TestNormalizeUsage:
|
||||
)
|
||||
assert result == {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20}
|
||||
|
||||
def test_preserves_token_details(self):
|
||||
"""Provider token details such as cache counters are preserved."""
|
||||
result = litellmchat.LiteLLMRequester._normalize_usage(
|
||||
{
|
||||
'prompt_tokens': 12,
|
||||
'completion_tokens': 8,
|
||||
'total_tokens': 20,
|
||||
'prompt_tokens_details': {'cached_tokens': 7},
|
||||
'completion_tokens_details': {'reasoning_tokens': 3},
|
||||
}
|
||||
)
|
||||
|
||||
assert result['prompt_tokens'] == 12
|
||||
assert result['prompt_tokens_details'] == {'cached_tokens': 7}
|
||||
assert result['completion_tokens_details'] == {'reasoning_tokens': 3}
|
||||
|
||||
def test_missing_total_is_derived(self):
|
||||
"""When total_tokens is absent/zero it is derived from prompt + completion"""
|
||||
usage = Mock()
|
||||
@@ -166,9 +191,7 @@ class TestInvokeLLMStreamUsage:
|
||||
if has_choice:
|
||||
choice = Mock()
|
||||
delta = Mock()
|
||||
delta.model_dump = Mock(
|
||||
return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls}
|
||||
)
|
||||
delta.model_dump = Mock(return_value={'role': 'assistant', 'content': content, 'tool_calls': tool_calls})
|
||||
choice.delta = delta
|
||||
choice.finish_reason = finish_reason
|
||||
chunk.choices = [choice]
|
||||
@@ -313,7 +336,8 @@ class TestInvokeLLMStreamUsage:
|
||||
|
||||
with patch.object(litellmchat, 'acompletion', new=AsyncMock(side_effect=lambda **kw: _aiter())):
|
||||
collected = [
|
||||
chunk async for chunk in requester.invoke_llm_stream(
|
||||
chunk
|
||||
async for chunk in requester.invoke_llm_stream(
|
||||
query=query,
|
||||
model=model,
|
||||
messages=messages,
|
||||
@@ -788,7 +812,9 @@ class TestInvokeRerank:
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
# arerank must NOT be called on the openai-compatible path
|
||||
with patch.object(
|
||||
litellmchat, 'arerank', new_callable=AsyncMock,
|
||||
litellmchat,
|
||||
'arerank',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=AssertionError('arerank must not be used for openai-compatible provider'),
|
||||
):
|
||||
results = await requester.invoke_rerank(
|
||||
@@ -1068,8 +1094,7 @@ class TestScanModels:
|
||||
|
||||
with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling:
|
||||
mock_supports_function_calling.side_effect = (
|
||||
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6'
|
||||
and custom_llm_provider is None
|
||||
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6' and custom_llm_provider is None
|
||||
)
|
||||
|
||||
assert requester._supports_function_calling('kimi-k2.6') is True
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
@@ -88,6 +89,28 @@ def test_token_manager_next_token_ignores_empty_token_list():
|
||||
assert token_mgr.using_token_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_manager_initialize_skips_space_sync_after_timeout():
|
||||
ap = SimpleNamespace()
|
||||
ap.discover = SimpleNamespace(get_components_by_kind=Mock(return_value=[]))
|
||||
ap.instance_config = SimpleNamespace(data={'space': {'models_sync_timeout': 0.01}})
|
||||
ap.logger = Mock()
|
||||
|
||||
mgr = ModelManager(ap)
|
||||
mgr.load_models_from_db = AsyncMock()
|
||||
|
||||
async def slow_sync():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
mgr.sync_new_models_from_space = AsyncMock(side_effect=slow_sync)
|
||||
|
||||
await mgr.initialize()
|
||||
|
||||
mgr.load_models_from_db.assert_awaited_once()
|
||||
mgr.sync_new_models_from_space.assert_awaited_once()
|
||||
ap.logger.warning.assert_any_call('LangBot Space model sync timed out after 0.01s, skipping startup sync.')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline():
|
||||
from langbot.pkg.api.http.service.model import LLMModelsService
|
||||
|
||||
@@ -226,7 +226,7 @@ class TestToolManagerExecuteFuncCall:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_raises_when_tool_not_found(self, mock_app_with_loaders, sample_query):
|
||||
"""Test that execute_func_call raises ValueError when tool not found."""
|
||||
"""Test that execute_func_call raises ToolNotFoundError when tool not found."""
|
||||
toolmgr = get_toolmgr_module()
|
||||
|
||||
mock_app, mock_plugin_loader, mock_mcp_loader = mock_app_with_loaders
|
||||
@@ -236,7 +236,7 @@ class TestToolManagerExecuteFuncCall:
|
||||
manager = toolmgr.ToolManager(mock_app)
|
||||
self._wire_loaders(manager, mock_app, mock_plugin_loader, mock_mcp_loader)
|
||||
|
||||
with pytest.raises(ValueError, match='未找到工具'):
|
||||
with pytest.raises(toolmgr.ToolNotFoundError, match='Tool not found: unknown_tool'):
|
||||
await manager.execute_func_call('unknown_tool', {}, sample_query)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user