mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-28 00:14:21 +00:00
chore(agent-runner): merge split tool runtime base
# Conflicts: # src/langbot/pkg/box/workspace.py # src/langbot/pkg/provider/tools/loaders/mcp_stdio.py # src/langbot/pkg/provider/tools/loaders/native.py # src/langbot/pkg/provider/tools/loaders/skill.py # tests/unit_tests/box/test_workspace.py # tests/unit_tests/provider/test_mcp_box_integration.py
This commit is contained in:
@@ -12,6 +12,19 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
|
||||
|
||||
LLM_USAGE_QUERY_VARIABLE = '_llm_usage'
|
||||
STREAM_USAGE_QUERY_VARIABLE = '_stream_usage'
|
||||
|
||||
|
||||
def _store_llm_usage(query: pipeline_query.Query | None, usage_info: dict | None) -> None:
|
||||
"""Store the latest provider usage on the query for upstream action handlers."""
|
||||
if query is None or not usage_info:
|
||||
return
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
query.variables[LLM_USAGE_QUERY_VARIABLE] = dict(usage_info)
|
||||
|
||||
|
||||
class RuntimeProvider:
|
||||
"""运行时模型提供商"""
|
||||
|
||||
@@ -67,6 +80,7 @@ class RuntimeProvider:
|
||||
if isinstance(result, tuple):
|
||||
msg, usage_info = result
|
||||
if usage_info:
|
||||
_store_llm_usage(query, usage_info)
|
||||
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||
output_tokens = usage_info.get('completion_tokens', 0)
|
||||
return msg
|
||||
@@ -146,11 +160,12 @@ class RuntimeProvider:
|
||||
if query:
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
if '_stream_usage' in query.variables:
|
||||
usage_info = query.variables['_stream_usage']
|
||||
if STREAM_USAGE_QUERY_VARIABLE in query.variables:
|
||||
usage_info = query.variables[STREAM_USAGE_QUERY_VARIABLE]
|
||||
_store_llm_usage(query, usage_info)
|
||||
input_tokens = usage_info.get('prompt_tokens', 0)
|
||||
output_tokens = usage_info.get('completion_tokens', 0)
|
||||
del query.variables['_stream_usage']
|
||||
del query.variables[STREAM_USAGE_QUERY_VARIABLE]
|
||||
except Exception as e:
|
||||
status = 'error'
|
||||
error_message = str(e)
|
||||
|
||||
@@ -75,22 +75,33 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
continue
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _positive_int(value: typing.Any) -> int | None:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if isinstance(value, int) and value > 0:
|
||||
return value
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
parsed_value = int(value)
|
||||
if parsed_value > 0:
|
||||
return parsed_value
|
||||
return None
|
||||
|
||||
def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None:
|
||||
if not model_payload:
|
||||
return None
|
||||
|
||||
for field_name in ('context_length', 'context_window', 'max_context_length'):
|
||||
value = model_payload.get(field_name)
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int) and value > 0:
|
||||
return value
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
parsed_value = int(value)
|
||||
if parsed_value > 0:
|
||||
return parsed_value
|
||||
context_length = self._positive_int(model_payload.get(field_name))
|
||||
if context_length is not None:
|
||||
return context_length
|
||||
return None
|
||||
|
||||
def _context_length_from_litellm_model_info(self, model_info: typing.Any) -> int | None:
|
||||
if isinstance(model_info, dict):
|
||||
return self._positive_int(model_info.get('max_input_tokens'))
|
||||
return self._positive_int(getattr(model_info, 'max_input_tokens', None))
|
||||
|
||||
def _metadata_provider_candidates(self, model_name: str) -> list[str]:
|
||||
normalized_model_name = (model_name or '').lower()
|
||||
candidates = []
|
||||
@@ -126,7 +137,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
return None
|
||||
|
||||
def _safe_context_length(self, model_name: str) -> int | None:
|
||||
helper = getattr(litellm, 'get_max_tokens', None)
|
||||
helper = getattr(litellm, 'get_model_info', None)
|
||||
if not callable(helper):
|
||||
return self._known_context_length_fallback(model_name)
|
||||
|
||||
@@ -143,11 +154,12 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
continue
|
||||
tried_candidates.append(candidate)
|
||||
try:
|
||||
max_tokens = helper(candidate)
|
||||
model_info = helper(candidate)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(max_tokens, int) and max_tokens > 0:
|
||||
return max_tokens
|
||||
context_length = self._context_length_from_litellm_model_info(model_info)
|
||||
if context_length is not None:
|
||||
return context_length
|
||||
return self._known_context_length_fallback(model_name)
|
||||
|
||||
def _supports_function_calling(self, model_name: str) -> bool:
|
||||
@@ -250,32 +262,82 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
- dict with the same keys
|
||||
- missing ``total_tokens`` (derived from prompt + completion)
|
||||
- ``None`` / partially-populated usage (defaults to 0)
|
||||
- provider-specific token details, including cache token counters
|
||||
"""
|
||||
if usage is None:
|
||||
return {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0}
|
||||
|
||||
def _get(key: str) -> typing.Any:
|
||||
if isinstance(usage, dict):
|
||||
return usage.get(key)
|
||||
return getattr(usage, key, None)
|
||||
def _plain_value(value: typing.Any) -> typing.Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, dict):
|
||||
return {k: _plain_value(v) for k, v in value.items() if v is not None}
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_plain_value(v) for v in value]
|
||||
|
||||
prompt_tokens = _get('prompt_tokens') or 0
|
||||
completion_tokens = _get('completion_tokens') or 0
|
||||
total_tokens = _get('total_tokens') or 0
|
||||
model_dump = getattr(value, 'model_dump', None)
|
||||
if callable(model_dump):
|
||||
try:
|
||||
dumped = model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return _plain_value(dumped)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
def _usage_dict(value: typing.Any) -> dict[str, typing.Any]:
|
||||
if value is None:
|
||||
return {}
|
||||
plain = _plain_value(value)
|
||||
if isinstance(plain, dict):
|
||||
return plain
|
||||
|
||||
def _is_mock_attr(attr: typing.Any) -> bool:
|
||||
return type(attr).__module__.startswith('unittest.mock')
|
||||
|
||||
data: dict[str, typing.Any] = {}
|
||||
for key in (
|
||||
'prompt_tokens',
|
||||
'completion_tokens',
|
||||
'total_tokens',
|
||||
'prompt_tokens_details',
|
||||
'completion_tokens_details',
|
||||
'cache_creation_input_tokens',
|
||||
'cache_read_input_tokens',
|
||||
'input_token_details',
|
||||
'output_token_details',
|
||||
):
|
||||
attr_value = getattr(value, key, None)
|
||||
if attr_value is not None and not _is_mock_attr(attr_value):
|
||||
data[key] = _plain_value(attr_value)
|
||||
return data
|
||||
|
||||
def _to_int(value: typing.Any) -> int:
|
||||
try:
|
||||
return int(value or 0)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
normalized = _usage_dict(usage)
|
||||
|
||||
prompt_tokens = _to_int(normalized.get('prompt_tokens'))
|
||||
completion_tokens = _to_int(normalized.get('completion_tokens'))
|
||||
total_tokens = _to_int(normalized.get('total_tokens'))
|
||||
|
||||
# Some providers omit total_tokens in streaming usage; derive it.
|
||||
if not total_tokens:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
return {
|
||||
'prompt_tokens': int(prompt_tokens),
|
||||
'completion_tokens': int(completion_tokens),
|
||||
'total_tokens': int(total_tokens),
|
||||
}
|
||||
normalized['prompt_tokens'] = prompt_tokens
|
||||
normalized['completion_tokens'] = completion_tokens
|
||||
normalized['total_tokens'] = total_tokens
|
||||
return normalized
|
||||
|
||||
def _extract_usage(self, response) -> dict:
|
||||
def _extract_usage(self, response) -> dict | None:
|
||||
"""Extract usage info from a non-streaming LiteLLM response."""
|
||||
return self._normalize_usage(getattr(response, 'usage', None))
|
||||
usage = getattr(response, 'usage', None)
|
||||
if usage is None:
|
||||
return None
|
||||
return self._normalize_usage(usage)
|
||||
|
||||
@staticmethod
|
||||
def _as_dict(value: typing.Any) -> dict:
|
||||
@@ -474,7 +536,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
if query is not None:
|
||||
if query.variables is None:
|
||||
query.variables = {}
|
||||
query.variables['_stream_usage'] = usage_info
|
||||
query.variables[requester.STREAM_USAGE_QUERY_VARIABLE] = usage_info
|
||||
|
||||
if not hasattr(chunk, 'choices') or not chunk.choices:
|
||||
continue
|
||||
|
||||
@@ -117,6 +117,7 @@ class SkillToolLoader(loader.ToolLoader):
|
||||
'activated': True,
|
||||
'skill_name': skill_name,
|
||||
'mount_path': mount_path,
|
||||
'activated_skill_names': skill_loader.get_activated_skill_names(query),
|
||||
'content': result_content,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user