mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-16 18:56:02 +00:00
refactor(provider): use LiteLLM as unified LLM requester backend (#2150)
* refactor(provider): use LiteLLM as unified LLM requester backend
- Replace 23+ individual requester implementations with unified litellmchat.py
- Add litellm_provider field to 27 YAML manifests for provider routing
- Delete redundant requester subclasses
- Add unit tests for LiteLLMRequester (29 tests)
- Fix num_retries parameter name (was max_retries)
- Fix exception handling order for subclass exceptions
LiteLLM provides unified API for 100+ providers, eliminating need for
provider-specific requesters.
* fix: ruff format provider.py
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* refactor(provider): simplify LiteLLM requester usage handling
- Remove unused Anthropic-specific tool schema generation
- Share completion argument construction between normal and streaming calls
- Use LiteLLM/OpenAI native usage fields for monitoring
- Collect stream token usage from LiteLLM stream_options
- Update LiteLLM requester tests for unified usage fields
* restore: restore deleted provider requester files
Restore individual provider requester implementations that were
removed in de61b5d3. These files coexist with the unified
litellmchat.py backend.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* feat: update requesters and improve provider selection UI
- Added `litellm_provider` field to various requesters' YAML configurations.
- Removed obsolete Python requester files for OpenRouter, PPIO, QHAIGC, ShengSuanYun, SiliconFlow, Space, TokenPony, VolcArk, and Xai.
- Introduced new requesters for Tencent and Together AI with corresponding YAML configurations and SVG icons.
- Enhanced the ProviderForm component to include a searchable dropdown for selecting providers, improving user experience.
- Updated localization files to include search provider text for both English and Chinese.
* fix(provider): align litellm rebase with master
* fix(provider): capture streaming token usage; add token observability
The LiteLLM streaming requester only captured usage when a chunk had an
empty `choices` list. Many OpenAI-compatible gateways (e.g. new-api) and
providers send the final usage payload in a chunk that still carries an
empty-delta choice, so streamed calls always recorded 0 tokens in the
monitoring logs/dashboard (non-streaming worked).
- Capture stream usage whenever a chunk carries it, regardless of choices
- Add robust _normalize_usage (dict/obj shapes, derive missing total_tokens)
- Register litellm in bootutils/deps.py (was in pyproject only)
- Add MonitoringService.get_token_statistics + /monitoring/token-statistics
endpoint: summary, per-model breakdown, token timeseries, and a
zero-token-success data-quality signal
- Add TokenMonitoring dashboard tab (summary tiles, stacked token chart,
per-model table) + i18n (en/zh)
- Regression tests for stream usage capture and usage normalization
Verified end-to-end against a real OpenAI-compatible endpoint with
gpt-5.5 and claude-opus-4-8: tokens now recorded non-zero for both
streaming and non-streaming paths.
* refactor(provider): simplify litellm capabilities
* style: simplify wrapped expressions
* feat(models): persist context metadata
* fix(provider): handle dict embeddings and openai-compatible rerank in LiteLLMRequester
- invoke_embedding: support both object- and dict-shaped response.data
entries (OpenAI-compatible gateways like new-api return dicts)
- invoke_rerank: litellm.arerank rejects the 'openai' provider, so for
openai-compatible (or unspecified) providers call the standard
Jina/Cohere-style POST /v1/rerank endpoint directly over HTTP
- accept both 'relevance_score' and 'score' fields in rerank results
- add unit tests for the openai-compatible HTTP rerank path
* feat(provider): enforce requester support_type when adding models
- frontend: AddModelPopover only shows model-type tabs (llm/embedding/
rerank) that the provider's requester declares in its manifest
support_type; ModelsDialog fetches requester manifests and maps
requester -> support_type, passed down through ProviderCard
- backend: add _validate_provider_supports guard in create_llm_model /
create_embedding_model / create_rerank_model so a model cannot be
attached to a provider whose requester does not support that type,
even if the frontend restriction is bypassed (manifests without
support_type are allowed for backward compatibility)
- manifests: correct support_type for providers that do not offer all
three model types:
- llm only: anthropic, deepseek, groq, moonshot, openrouter, xai
- llm + text-embedding: openai, gemini, mistral
- add rerank to new-api (verified working via /v1/rerank)
- set llm + text-embedding + rerank for aggregator/unknown gateways
* feat(provider): add searchable alias to requester manifests
- add a free-text 'alias' field to every requester manifest spec,
containing the vendor's English/Chinese names, pinyin, common
nicknames and flagship model-series names (e.g. moonshot -> kimi,
月之暗面; zhipu -> glm, 智谱清言)
- frontend: ProviderForm requester search now also matches against
alias (substring/contains), so searching 'kimi' surfaces Moonshot,
'硅基' surfaces SiliconFlow, etc.
- also fix support_type: openrouter (relay) supports embedding+rerank;
LangBot Space gains rerank (coming soon)
* fix(provider): make support_type guard defensive against incomplete model_mgr
- _validate_provider_supports now uses getattr to gracefully skip when
model_mgr / provider_dict / manifest lookup is unavailable, instead of
raising AttributeError (fixes unit tests that mock ap.model_mgr as a
bare SimpleNamespace)
- add TestValidateProviderSupports covering: allow supported type,
reject unsupported type, allow when support_type missing, allow when
provider unknown, degrade safely when model_mgr is incomplete
* fix(persistence): guard 0004 migration against missing llm_models table
The 0004_add_llm_model_context_length migration called
inspector.get_columns('llm_models') unconditionally, raising
NoSuchTableError when the table does not exist (e.g. migrating a
fresh/empty DB, as exercised by the integration tests where
create_all() registers no tables because the ORM models are not
imported). Every other migration guards with a table-existence check
first; add the same guard here for both upgrade and downgrade.
Also restore the test head assertion to 0004 (it had been lowered to
0003 to mask this failure).
* Merge branch 'master' into feat/litellm
Resolve conflicts:
- uv.lock: regenerated via 'uv lock' to reconcile litellm/fastuuid
(ours) with openai bump (master).
- Alembic migrations: master added 0004_add_mcp_readme while this
branch added 0004_add_llm_model_context_length, both as children of
0003 (would create multiple heads). Re-chain the litellm migration as
0005_add_llm_model_context_length with down_revision=0004_add_mcp_readme
for a single linear head. Update test head assertion accordingly.
* fix(persistence): shorten migration revision id to fit varchar(32)
PostgreSQL stores alembic_version.version_num as varchar(32).
'0005_add_llm_model_context_length' (33 chars) overflowed it, raising
StringDataRightTruncationError in the PG migration tests. Rename the
revision (and file) to '0005_add_llm_context_length' (27 chars) and
update the head assertions in both SQLite and PostgreSQL migration
tests.
---------
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: fdc310 <2213070223@qq.com>
Co-authored-by: RockChinQ <rockchinq@gmail.com>
This commit is contained in:
@@ -104,7 +104,7 @@ class TestSQLiteMigrationUpgrade:
|
||||
rev = await get_alembic_current(sqlite_engine)
|
||||
assert rev is not None, "Expected a revision after upgrade"
|
||||
# Head should be the latest migration
|
||||
assert rev.startswith('0004'), f"Expected head to be 0004_*, got {rev}"
|
||||
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upgrade_idempotent(self, sqlite_engine):
|
||||
|
||||
@@ -150,8 +150,8 @@ class TestPostgreSQLMigrationUpgrade:
|
||||
# Verify revision
|
||||
rev = await get_alembic_current(postgres_engine)
|
||||
assert rev is not None, "Expected a revision after upgrade"
|
||||
# Head should be the latest migration (0004 for current state)
|
||||
assert rev.startswith('0004'), f"Expected head to be 0004_*, got {rev}"
|
||||
# Head should be the latest migration (0005 for current state)
|
||||
assert rev.startswith('0005'), f"Expected head to be 0005_*, got {rev}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_upgrade_idempotent(
|
||||
|
||||
@@ -23,6 +23,7 @@ from langbot.pkg.api.http.service.model import (
|
||||
RerankModelsService,
|
||||
_parse_provider_api_keys,
|
||||
_runtime_model_data,
|
||||
_validate_provider_supports,
|
||||
)
|
||||
from langbot.pkg.entity.persistence.model import LLMModel, EmbeddingModel, RerankModel, ModelProvider
|
||||
|
||||
@@ -35,6 +36,7 @@ def _create_mock_llm_model(
|
||||
name: str = 'Test LLM',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
abilities: list = None,
|
||||
context_length: int | None = None,
|
||||
extra_args: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock LLMModel entity."""
|
||||
@@ -43,6 +45,7 @@ def _create_mock_llm_model(
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.abilities = abilities or []
|
||||
model.context_length = context_length
|
||||
model.extra_args = extra_args or {}
|
||||
return model
|
||||
|
||||
@@ -142,10 +145,12 @@ class TestRuntimeModelData:
|
||||
'name': 'Model',
|
||||
'provider_uuid': 'provider',
|
||||
'abilities': ['vision'],
|
||||
'context_length': 128000,
|
||||
'extra_args': {'temp': 0.7},
|
||||
}
|
||||
result = _runtime_model_data('uuid', update_payload)
|
||||
assert result['abilities'] == ['vision']
|
||||
assert result['context_length'] == 128000
|
||||
assert result['extra_args'] == {'temp': 0.7}
|
||||
|
||||
|
||||
@@ -188,7 +193,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model()
|
||||
model = _create_mock_llm_model(context_length=128000)
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
@@ -206,6 +211,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
|
||||
'context_length': getattr(entity, 'context_length', None),
|
||||
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
|
||||
}
|
||||
)
|
||||
@@ -218,6 +224,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0]['name'] == 'Test LLM'
|
||||
assert result[0]['context_length'] == 128000
|
||||
|
||||
async def test_get_llm_models_hide_secret_keys(self):
|
||||
"""Hides secret API keys when include_secret=False."""
|
||||
@@ -265,7 +272,7 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model(model_uuid='found-uuid')
|
||||
model = _create_mock_llm_model(model_uuid='found-uuid', context_length=128000)
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
@@ -279,11 +286,12 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-uuid',
|
||||
'name': 'Test LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']},
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': getattr(entity, 'provider_uuid', None),
|
||||
'context_length': getattr(entity, 'context_length', None),
|
||||
'api_keys': getattr(entity, 'api_keys', None),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -295,6 +303,7 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'found-uuid'
|
||||
assert result['context_length'] == 128000
|
||||
|
||||
async def test_get_llm_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
@@ -402,6 +411,39 @@ class TestLLMModelsServiceCreateLLMModel:
|
||||
# Verify
|
||||
assert model_uuid == 'preserved-uuid'
|
||||
|
||||
async def test_create_llm_model_persists_context_length_as_column(self):
|
||||
"""Creates LLM model with context_length outside extra_args."""
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
ap.pipeline_service = SimpleNamespace(update_pipeline=AsyncMock())
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
await service.create_llm_model(
|
||||
{
|
||||
'uuid': 'model-with-context',
|
||||
'name': 'Context Model',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': ['func_call'],
|
||||
'context_length': 128000,
|
||||
'extra_args': {'temperature': 0.2},
|
||||
},
|
||||
preserve_uuid=True,
|
||||
auto_set_to_default_pipeline=False,
|
||||
)
|
||||
|
||||
runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0]
|
||||
assert runtime_entity.context_length == 128000
|
||||
assert runtime_entity.extra_args == {'temperature': 0.2}
|
||||
assert 'context_length' not in runtime_entity.extra_args
|
||||
|
||||
async def test_create_llm_model_provider_not_found_raises_error(self):
|
||||
"""Raises Exception when provider not found in runtime."""
|
||||
# Setup
|
||||
@@ -512,6 +554,35 @@ class TestLLMModelsServiceUpdateLLMModel:
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
})
|
||||
|
||||
async def test_update_llm_model_reloads_context_length_as_column(self):
|
||||
"""Updates runtime model with context_length outside extra_args."""
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.remove_llm_model = AsyncMock()
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
await service.update_llm_model(
|
||||
'existing-uuid',
|
||||
{
|
||||
'name': 'Updated Name',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': ['vision'],
|
||||
'context_length': 64000,
|
||||
'extra_args': {'temperature': 0.4},
|
||||
},
|
||||
)
|
||||
|
||||
runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0]
|
||||
assert runtime_entity.uuid == 'existing-uuid'
|
||||
assert runtime_entity.context_length == 64000
|
||||
assert runtime_entity.extra_args == {'temperature': 0.4}
|
||||
assert 'context_length' not in runtime_entity.extra_args
|
||||
|
||||
|
||||
class TestLLMModelsServiceDeleteLLMModel:
|
||||
"""Tests for LLMModelsService.delete_llm_model method."""
|
||||
@@ -961,4 +1032,56 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
|
||||
result = await service.get_rerank_models_by_provider('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestValidateProviderSupports:
|
||||
"""Tests for _validate_provider_supports guard."""
|
||||
|
||||
@staticmethod
|
||||
def _make_ap(requester_name: str, support_type):
|
||||
"""Build a fake ap whose model_mgr resolves a manifest with support_type."""
|
||||
manifest = SimpleNamespace(spec={'support_type': support_type})
|
||||
runtime_provider = SimpleNamespace(
|
||||
provider_entity=SimpleNamespace(requester=requester_name)
|
||||
)
|
||||
model_mgr = SimpleNamespace(
|
||||
provider_dict={'p1': runtime_provider},
|
||||
get_available_requester_manifest_by_name=lambda name: manifest
|
||||
if name == requester_name
|
||||
else None,
|
||||
)
|
||||
return SimpleNamespace(model_mgr=model_mgr)
|
||||
|
||||
async def test_allows_supported_type(self):
|
||||
ap = self._make_ap('cohere-rerank', ['rerank'])
|
||||
# Should not raise
|
||||
await _validate_provider_supports(ap, 'p1', 'rerank')
|
||||
|
||||
async def test_rejects_unsupported_type(self):
|
||||
ap = self._make_ap('cohere-rerank', ['rerank'])
|
||||
with pytest.raises(ValueError, match='does not support llm'):
|
||||
await _validate_provider_supports(ap, 'p1', 'llm')
|
||||
|
||||
async def test_allows_when_support_type_missing(self):
|
||||
# Manifest without support_type must not block (backward compatible)
|
||||
manifest = SimpleNamespace(spec={})
|
||||
runtime_provider = SimpleNamespace(
|
||||
provider_entity=SimpleNamespace(requester='legacy')
|
||||
)
|
||||
model_mgr = SimpleNamespace(
|
||||
provider_dict={'p1': runtime_provider},
|
||||
get_available_requester_manifest_by_name=lambda name: manifest,
|
||||
)
|
||||
ap = SimpleNamespace(model_mgr=model_mgr)
|
||||
await _validate_provider_supports(ap, 'p1', 'rerank')
|
||||
|
||||
async def test_allows_when_provider_unknown(self):
|
||||
ap = self._make_ap('cohere-rerank', ['rerank'])
|
||||
# Unknown provider uuid -> no entry -> no block
|
||||
await _validate_provider_supports(ap, 'missing', 'llm')
|
||||
|
||||
async def test_degrades_when_model_mgr_incomplete(self):
|
||||
# A bare ap without a usable model_mgr must not raise (defensive)
|
||||
ap = SimpleNamespace(model_mgr=SimpleNamespace())
|
||||
await _validate_provider_supports(ap, 'p1', 'llm')
|
||||
|
||||
@@ -1 +1 @@
|
||||
|
||||
"""Provider requester tests"""
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Tests for AnthropicMessages requester.
|
||||
|
||||
Tests config and pure utility methods.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class TestAnthropicMessagesConfig:
|
||||
"""Tests for default config."""
|
||||
|
||||
def test_default_config_values(self):
|
||||
"""Check default_config."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.anthropicmsgs import AnthropicMessages
|
||||
|
||||
assert AnthropicMessages.default_config['base_url'] == 'https://api.anthropic.com'
|
||||
assert AnthropicMessages.default_config['timeout'] == 120
|
||||
|
||||
def test_config_override(self):
|
||||
"""Config can override defaults."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.anthropicmsgs import AnthropicMessages
|
||||
|
||||
mock_app = MagicMock()
|
||||
req = AnthropicMessages(mock_app, {
|
||||
'base_url': 'https://custom.anthropic.com',
|
||||
'timeout': 60,
|
||||
})
|
||||
|
||||
assert req.requester_cfg['base_url'] == 'https://custom.anthropic.com'
|
||||
assert req.requester_cfg['timeout'] == 60
|
||||
@@ -1,247 +0,0 @@
|
||||
"""Tests for requester error handling - direct import version.
|
||||
|
||||
Tests error handling branches by importing real packages and mocking
|
||||
only the necessary dependencies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
import openai # Import real openai package
|
||||
|
||||
from langbot.pkg.provider.modelmgr.errors import RequesterError
|
||||
|
||||
|
||||
class TestInvokeLLMErrorHandling:
|
||||
"""Tests for invoke_llm error handling branches."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
"""Create mock Application."""
|
||||
app = MagicMock()
|
||||
app.tool_mgr = MagicMock()
|
||||
app.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=[])
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self):
|
||||
"""Create mock RuntimeLLMModel."""
|
||||
model = MagicMock()
|
||||
model.model_entity = MagicMock()
|
||||
model.model_entity.name = 'gpt-4'
|
||||
model.provider = MagicMock()
|
||||
model.provider.token_mgr = MagicMock()
|
||||
model.provider.token_mgr.get_token = MagicMock(return_value='test-key')
|
||||
return model
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create mock provider message."""
|
||||
msg = MagicMock()
|
||||
msg.dict = MagicMock(return_value={'role': 'user', 'content': 'test'})
|
||||
return msg
|
||||
|
||||
@pytest.fixture
|
||||
def requester_with_mocked_client(self, mock_app):
|
||||
"""Create requester with mocked OpenAI client."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
req = OpenAIChatCompletions(mock_app, {
|
||||
'base_url': 'https://api.openai.com/v1',
|
||||
'timeout': 120,
|
||||
})
|
||||
|
||||
# Replace client with mock
|
||||
req.client = MagicMock()
|
||||
req.client.chat = MagicMock()
|
||||
req.client.chat.completions = MagicMock()
|
||||
req.client.chat.completions.create = AsyncMock()
|
||||
|
||||
return req
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_error(self, requester_with_mocked_client, mock_model, mock_message):
|
||||
"""TimeoutError is wrapped as RequesterError."""
|
||||
requester_with_mocked_client.client.chat.completions.create = AsyncMock(
|
||||
side_effect=asyncio.TimeoutError()
|
||||
)
|
||||
|
||||
with pytest.raises(RequesterError) as exc:
|
||||
await requester_with_mocked_client.invoke_llm(
|
||||
query=None,
|
||||
model=mock_model,
|
||||
messages=[mock_message],
|
||||
)
|
||||
|
||||
assert '超时' in str(exc.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_request_context_length(self, requester_with_mocked_client, mock_model, mock_message):
|
||||
"""BadRequestError with context_length_exceeded has special message."""
|
||||
error = openai.BadRequestError(
|
||||
message='context_length_exceeded: max 4096',
|
||||
response=MagicMock(status_code=400),
|
||||
body={}
|
||||
)
|
||||
requester_with_mocked_client.client.chat.completions.create = AsyncMock(
|
||||
side_effect=error
|
||||
)
|
||||
|
||||
with pytest.raises(RequesterError) as exc:
|
||||
await requester_with_mocked_client.invoke_llm(
|
||||
query=None,
|
||||
model=mock_model,
|
||||
messages=[mock_message],
|
||||
)
|
||||
|
||||
assert '上文过长' in str(exc.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_error(self, requester_with_mocked_client, mock_model, mock_message):
|
||||
"""AuthenticationError shows invalid api-key message."""
|
||||
error = openai.AuthenticationError(
|
||||
message='Invalid API key',
|
||||
response=MagicMock(status_code=401),
|
||||
body={}
|
||||
)
|
||||
requester_with_mocked_client.client.chat.completions.create = AsyncMock(
|
||||
side_effect=error
|
||||
)
|
||||
|
||||
with pytest.raises(RequesterError) as exc:
|
||||
await requester_with_mocked_client.invoke_llm(
|
||||
query=None,
|
||||
model=mock_model,
|
||||
messages=[mock_message],
|
||||
)
|
||||
|
||||
assert 'api-key' in str(exc.value).lower() or '无效' in str(exc.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_error(self, requester_with_mocked_client, mock_model, mock_message):
|
||||
"""RateLimitError shows rate limit message."""
|
||||
error = openai.RateLimitError(
|
||||
message='Rate limit exceeded',
|
||||
response=MagicMock(status_code=429),
|
||||
body={}
|
||||
)
|
||||
requester_with_mocked_client.client.chat.completions.create = AsyncMock(
|
||||
side_effect=error
|
||||
)
|
||||
|
||||
with pytest.raises(RequesterError) as exc:
|
||||
await requester_with_mocked_client.invoke_llm(
|
||||
query=None,
|
||||
model=mock_model,
|
||||
messages=[mock_message],
|
||||
)
|
||||
|
||||
assert '频繁' in str(exc.value) or '余额' in str(exc.value)
|
||||
|
||||
|
||||
class TestInvokeEmbeddingErrorHandling:
|
||||
"""Tests for invoke_embedding error handling."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_model(self):
|
||||
model = MagicMock()
|
||||
model.model_entity = MagicMock()
|
||||
model.model_entity.name = 'text-embedding-ada-002'
|
||||
model.model_entity.extra_args = {}
|
||||
model.provider = MagicMock()
|
||||
model.provider.token_mgr = MagicMock()
|
||||
model.provider.token_mgr.get_token = MagicMock(return_value='test-key')
|
||||
return model
|
||||
|
||||
@pytest.fixture
|
||||
def requester_with_mocked_client(self, mock_app):
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
req = OpenAIChatCompletions(mock_app, {})
|
||||
req.client = MagicMock()
|
||||
req.client.embeddings = MagicMock()
|
||||
req.client.embeddings.create = AsyncMock()
|
||||
|
||||
return req
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_timeout_error(self, requester_with_mocked_client, mock_embedding_model):
|
||||
"""TimeoutError in embedding request."""
|
||||
requester_with_mocked_client.client.embeddings.create = AsyncMock(
|
||||
side_effect=asyncio.TimeoutError()
|
||||
)
|
||||
|
||||
with pytest.raises(RequesterError) as exc:
|
||||
await requester_with_mocked_client.invoke_embedding(
|
||||
model=mock_embedding_model,
|
||||
input_text=['test'],
|
||||
)
|
||||
|
||||
assert '超时' in str(exc.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_bad_request_error(self, requester_with_mocked_client, mock_embedding_model):
|
||||
"""BadRequestError in embedding request."""
|
||||
error = openai.BadRequestError(
|
||||
message='Invalid model',
|
||||
response=MagicMock(status_code=400),
|
||||
body={}
|
||||
)
|
||||
requester_with_mocked_client.client.embeddings.create = AsyncMock(
|
||||
side_effect=error
|
||||
)
|
||||
|
||||
with pytest.raises(RequesterError) as exc:
|
||||
await requester_with_mocked_client.invoke_embedding(
|
||||
model=mock_embedding_model,
|
||||
input_text=['test'],
|
||||
)
|
||||
|
||||
assert '参数' in str(exc.value)
|
||||
|
||||
|
||||
class TestRequesterErrorClass:
|
||||
"""Tests for RequesterError."""
|
||||
|
||||
def test_error_message_prefix(self):
|
||||
"""RequesterError has '模型请求失败' prefix."""
|
||||
from langbot.pkg.provider.modelmgr.errors import RequesterError
|
||||
|
||||
error = RequesterError('test error')
|
||||
assert '模型请求失败' in str(error)
|
||||
|
||||
def test_error_is_exception(self):
|
||||
"""RequesterError inherits Exception."""
|
||||
from langbot.pkg.provider.modelmgr.errors import RequesterError
|
||||
|
||||
error = RequesterError('test')
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
|
||||
class TestDefaultConfig:
|
||||
"""Tests for requester default config."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Check default_config values."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
assert OpenAIChatCompletions.default_config['base_url'] == 'https://api.openai.com/v1'
|
||||
assert OpenAIChatCompletions.default_config['timeout'] == 120
|
||||
|
||||
def test_config_override(self):
|
||||
"""Config overrides defaults."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
req = OpenAIChatCompletions(MagicMock(), {
|
||||
'base_url': 'https://custom.com/v1',
|
||||
'timeout': 60,
|
||||
})
|
||||
|
||||
assert req.requester_cfg['base_url'] == 'https://custom.com/v1'
|
||||
assert req.requester_cfg['timeout'] == 60
|
||||
@@ -1,340 +0,0 @@
|
||||
"""Tests for requester pure utility functions.
|
||||
|
||||
Tests the helper methods in OpenAIChatCompletions that don't require network calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tests.utils.import_isolation import isolated_sys_modules
|
||||
|
||||
|
||||
class TestMaskApiKey:
|
||||
"""Tests for _mask_api_key method."""
|
||||
|
||||
def _create_requester_with_mocks(self):
|
||||
"""Create requester instance with mocked dependencies."""
|
||||
mocks = {
|
||||
'langbot.pkg.core.app': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.provider.message': MagicMock(),
|
||||
'langbot.pkg.provider.modelmgr.errors': MagicMock(),
|
||||
}
|
||||
|
||||
with isolated_sys_modules(mocks):
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
mock_app = MagicMock()
|
||||
requester = OpenAIChatCompletions(mock_app, {})
|
||||
return requester
|
||||
|
||||
def test_mask_api_key_full(self):
|
||||
"""Mask a full API key."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._mask_api_key('sk-1234567890abcdef')
|
||||
assert result == 'sk-1...cdef'
|
||||
|
||||
def test_mask_api_key_short(self):
|
||||
"""Mask a short API key (<=8 chars)."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._mask_api_key('short')
|
||||
assert result == '****'
|
||||
|
||||
def test_mask_api_key_empty(self):
|
||||
"""Empty API key returns empty string."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._mask_api_key('')
|
||||
assert result == ''
|
||||
|
||||
def test_mask_api_key_none(self):
|
||||
"""None API key returns empty string."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._mask_api_key(None)
|
||||
assert result == ''
|
||||
|
||||
def test_mask_api_key_exact_8_chars(self):
|
||||
"""API key with exactly 8 chars is masked as **** (<=8 threshold)."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._mask_api_key('12345678')
|
||||
assert result == '****' # <= 8 chars gets masked
|
||||
|
||||
|
||||
class TestInferModelType:
|
||||
"""Tests for _infer_model_type method."""
|
||||
|
||||
def _create_requester_with_mocks(self):
|
||||
mocks = {
|
||||
'langbot.pkg.core.app': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.provider.message': MagicMock(),
|
||||
'langbot.pkg.provider.modelmgr.errors': MagicMock(),
|
||||
}
|
||||
|
||||
with isolated_sys_modules(mocks):
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
mock_app = MagicMock()
|
||||
requester = OpenAIChatCompletions(mock_app, {})
|
||||
return requester
|
||||
|
||||
def test_infer_embedding_from_name(self):
|
||||
"""Infer embedding type from model name."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
assert requester._infer_model_type('text-embedding-ada-002') == 'embedding'
|
||||
assert requester._infer_model_type('bge-large-en') == 'embedding'
|
||||
assert requester._infer_model_type('e5-base') == 'embedding'
|
||||
assert requester._infer_model_type('m3e-base') == 'embedding'
|
||||
|
||||
def test_infer_llm_from_name(self):
|
||||
"""Infer LLM type from model name."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
assert requester._infer_model_type('gpt-4') == 'llm'
|
||||
assert requester._infer_model_type('claude-3-opus') == 'llm'
|
||||
assert requester._infer_model_type('llama-2-70b') == 'llm'
|
||||
|
||||
def test_infer_model_type_none_id(self):
|
||||
"""Handle None model_id."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._infer_model_type(None)
|
||||
assert result == 'llm' # Default
|
||||
|
||||
def test_infer_model_type_empty_id(self):
|
||||
"""Handle empty model_id."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._infer_model_type('')
|
||||
assert result == 'llm' # Default
|
||||
|
||||
|
||||
class TestNormalizeModalities:
|
||||
"""Tests for _normalize_modalities method."""
|
||||
|
||||
def _create_requester_with_mocks(self):
|
||||
mocks = {
|
||||
'langbot.pkg.core.app': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.provider.message': MagicMock(),
|
||||
'langbot.pkg.provider.modelmgr.errors': MagicMock(),
|
||||
}
|
||||
|
||||
with isolated_sys_modules(mocks):
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
mock_app = MagicMock()
|
||||
requester = OpenAIChatCompletions(mock_app, {})
|
||||
return requester
|
||||
|
||||
def test_normalize_string_modality(self):
|
||||
"""Normalize single string modality."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._normalize_modalities('text,image')
|
||||
assert result == ['text', 'image']
|
||||
|
||||
def test_normalize_list_modalities(self):
|
||||
"""Normalize list of modalities."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._normalize_modalities(['text', 'image', 'audio'])
|
||||
assert result == ['text', 'image', 'audio']
|
||||
|
||||
def test_normalize_dict_modalities(self):
|
||||
"""Normalize dict with nested modalities."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._normalize_modalities({'input': ['text'], 'output': ['text', 'image']})
|
||||
assert result == ['text', 'image']
|
||||
|
||||
def test_normalize_none(self):
|
||||
"""Handle None input."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._normalize_modalities(None)
|
||||
assert result == []
|
||||
|
||||
def test_normalize_arrow_separator(self):
|
||||
"""Handle arrow separator in modality string."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
result = requester._normalize_modalities('text->image')
|
||||
assert result == ['text', 'image']
|
||||
|
||||
|
||||
class TestParseRerankResponse:
|
||||
"""Tests for _parse_rerank_response static method."""
|
||||
|
||||
def test_parse_cohere_jina_format(self):
|
||||
"""Parse Cohere/Jina/SiliconFlow format."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
data = {
|
||||
'results': [
|
||||
{'index': 0, 'relevance_score': 0.95},
|
||||
{'index': 1, 'relevance_score': 0.80},
|
||||
]
|
||||
}
|
||||
|
||||
result = OpenAIChatCompletions._parse_rerank_response(data)
|
||||
assert result == [
|
||||
{'index': 0, 'relevance_score': 0.95},
|
||||
{'index': 1, 'relevance_score': 0.80},
|
||||
]
|
||||
|
||||
def test_parse_voyage_format(self):
|
||||
"""Parse Voyage AI format."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
data = {
|
||||
'data': [
|
||||
{'index': 0, 'relevance_score': 0.90},
|
||||
{'index': 2, 'relevance_score': 0.75},
|
||||
]
|
||||
}
|
||||
|
||||
result = OpenAIChatCompletions._parse_rerank_response(data)
|
||||
assert result == [
|
||||
{'index': 0, 'relevance_score': 0.90},
|
||||
{'index': 2, 'relevance_score': 0.75},
|
||||
]
|
||||
|
||||
def test_parse_dashscope_format(self):
|
||||
"""Parse DashScope format."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
data = {
|
||||
'output': {
|
||||
'results': [
|
||||
{'index': 0, 'relevance_score': 0.85},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
result = OpenAIChatCompletions._parse_rerank_response(data)
|
||||
assert result == [{'index': 0, 'relevance_score': 0.85}]
|
||||
|
||||
def test_parse_unknown_format(self):
|
||||
"""Handle unknown format returns empty list."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
data = {'unknown_key': 'value'}
|
||||
|
||||
result = OpenAIChatCompletions._parse_rerank_response(data)
|
||||
assert result == []
|
||||
|
||||
def test_parse_empty_results(self):
|
||||
"""Handle empty results."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
data = {'results': []}
|
||||
|
||||
result = OpenAIChatCompletions._parse_rerank_response(data)
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestExtractScanMetadata:
|
||||
"""Tests for _extract_scan_metadata method."""
|
||||
|
||||
def _create_requester_with_mocks(self):
|
||||
mocks = {
|
||||
'langbot.pkg.core.app': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.resource.tool': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.pipeline.query': MagicMock(),
|
||||
'langbot_plugin.api.entities.builtin.provider.message': MagicMock(),
|
||||
'langbot.pkg.provider.modelmgr.errors': MagicMock(),
|
||||
}
|
||||
|
||||
with isolated_sys_modules(mocks):
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
|
||||
mock_app = MagicMock()
|
||||
requester = OpenAIChatCompletions(mock_app, {})
|
||||
return requester
|
||||
|
||||
def test_extract_basic_metadata(self):
|
||||
"""Extract basic model metadata."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
item = {
|
||||
'id': 'gpt-4',
|
||||
'name': 'GPT-4 Turbo',
|
||||
'description': 'Most capable GPT-4 model',
|
||||
'context_length': 128000,
|
||||
'owned_by': 'openai',
|
||||
}
|
||||
|
||||
result = requester._extract_scan_metadata(item, 'gpt-4')
|
||||
|
||||
assert result['display_name'] == 'GPT-4 Turbo'
|
||||
assert result['description'] == 'Most capable GPT-4 model'
|
||||
assert result['context_length'] == 128000
|
||||
assert result['owned_by'] == 'openai'
|
||||
|
||||
def test_extract_metadata_missing_fields(self):
|
||||
"""Handle missing metadata fields."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
item = {'id': 'unknown-model'}
|
||||
|
||||
result = requester._extract_scan_metadata(item, 'unknown-model')
|
||||
|
||||
assert result['display_name'] is None
|
||||
assert result['description'] is None
|
||||
assert result['context_length'] is None
|
||||
assert result['owned_by'] is None
|
||||
|
||||
def test_extract_metadata_top_provider_context(self):
|
||||
"""Extract context_length from top_provider."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
item = {
|
||||
'id': 'model',
|
||||
'top_provider': {
|
||||
'context_length': 4096,
|
||||
},
|
||||
}
|
||||
|
||||
result = requester._extract_scan_metadata(item, 'model')
|
||||
|
||||
assert result['context_length'] == 4096
|
||||
|
||||
def test_extract_metadata_empty_strings(self):
|
||||
"""Handle empty string values."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
item = {
|
||||
'id': 'model',
|
||||
'name': '', # Empty name
|
||||
'description': ' ', # Whitespace only
|
||||
'owned_by': '',
|
||||
}
|
||||
|
||||
result = requester._extract_scan_metadata(item, 'model')
|
||||
|
||||
assert result['display_name'] is None
|
||||
assert result['description'] is None
|
||||
assert result['owned_by'] is None
|
||||
|
||||
def test_extract_metadata_name_matches_id(self):
|
||||
"""When name equals id, display_name is None."""
|
||||
requester = self._create_requester_with_mocks()
|
||||
|
||||
item = {
|
||||
'id': 'gpt-4',
|
||||
'name': 'gpt-4', # Same as id
|
||||
}
|
||||
|
||||
result = requester._extract_scan_metadata(item, 'gpt-4')
|
||||
|
||||
assert result['display_name'] is None
|
||||
@@ -1,264 +0,0 @@
|
||||
"""Tests for OllamaChatCompletions requester.
|
||||
|
||||
Tests model inference, payload construction, and error handling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from langbot.pkg.provider.modelmgr.errors import RequesterError
|
||||
|
||||
|
||||
class TestOllamaRequesterConfig:
|
||||
"""Tests for default config."""
|
||||
|
||||
def test_default_config_values(self):
|
||||
"""Check default_config."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
|
||||
|
||||
assert OllamaChatCompletions.default_config['base_url'] == 'http://127.0.0.1:11434'
|
||||
assert OllamaChatCompletions.default_config['timeout'] == 120
|
||||
|
||||
def test_config_override(self):
|
||||
"""Config can override defaults."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
|
||||
|
||||
mock_app = MagicMock()
|
||||
req = OllamaChatCompletions(mock_app, {
|
||||
'base_url': 'http://custom.ollama:11434',
|
||||
'timeout': 300,
|
||||
})
|
||||
|
||||
assert req.requester_cfg['base_url'] == 'http://custom.ollama:11434'
|
||||
assert req.requester_cfg['timeout'] == 300
|
||||
|
||||
|
||||
class TestOllamaInferModelType:
|
||||
"""Tests for _infer_model_type pure function."""
|
||||
|
||||
@pytest.fixture
|
||||
def requester(self):
|
||||
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
|
||||
|
||||
return OllamaChatCompletions(MagicMock(), {})
|
||||
|
||||
def test_infer_embedding_from_name(self, requester):
|
||||
"""Embedding keywords return 'embedding'."""
|
||||
assert requester._infer_model_type('nomic-embed-text') == 'embedding'
|
||||
assert requester._infer_model_type('bge-large') == 'embedding'
|
||||
assert requester._infer_model_type('text-embedding') == 'embedding'
|
||||
|
||||
def test_infer_llm_from_name(self, requester):
|
||||
"""Non-embedding keywords return 'llm'."""
|
||||
assert requester._infer_model_type('llama2') == 'llm'
|
||||
assert requester._infer_model_type('mistral') == 'llm'
|
||||
assert requester._infer_model_type('codellama') == 'llm'
|
||||
|
||||
def test_infer_model_type_none(self, requester):
|
||||
"""None model_id returns 'llm'."""
|
||||
assert requester._infer_model_type(None) == 'llm'
|
||||
|
||||
def test_infer_model_type_empty(self, requester):
|
||||
"""Empty model_id returns 'llm'."""
|
||||
assert requester._infer_model_type('') == 'llm'
|
||||
|
||||
|
||||
class TestOllamaInferModelAbilities:
|
||||
"""Tests for _infer_model_abilities pure function."""
|
||||
|
||||
@pytest.fixture
|
||||
def requester(self):
|
||||
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
|
||||
|
||||
return OllamaChatCompletions(MagicMock(), {})
|
||||
|
||||
def test_infer_vision_ability(self, requester):
|
||||
"""Vision keywords add 'vision' ability."""
|
||||
item = {
|
||||
'details': {
|
||||
'family': 'llava',
|
||||
}
|
||||
}
|
||||
|
||||
abilities = requester._infer_model_abilities(item, 'llava-v1.5')
|
||||
assert 'vision' in abilities
|
||||
|
||||
def test_infer_vision_from_model_id(self, requester):
|
||||
"""Vision keywords in model_id add 'vision' ability."""
|
||||
item = {}
|
||||
abilities = requester._infer_model_abilities(item, 'llava-7b')
|
||||
assert 'vision' in abilities
|
||||
|
||||
def test_infer_func_call_ability(self, requester):
|
||||
"""Tool/function keywords add 'func_call' ability."""
|
||||
item = {
|
||||
'details': {
|
||||
'families': ['tools'],
|
||||
}
|
||||
}
|
||||
|
||||
abilities = requester._infer_model_abilities(item, 'model')
|
||||
assert 'func_call' in abilities
|
||||
|
||||
def test_infer_no_abilities(self, requester):
|
||||
"""No matching keywords returns empty abilities."""
|
||||
item = {
|
||||
'details': {
|
||||
'family': 'llama',
|
||||
}
|
||||
}
|
||||
|
||||
abilities = requester._infer_model_abilities(item, 'llama-2')
|
||||
assert len(abilities) == 0
|
||||
|
||||
def test_infer_multiple_abilities(self, requester):
|
||||
"""Multiple keywords can add multiple abilities."""
|
||||
item = {
|
||||
'details': {
|
||||
'family': 'vision',
|
||||
'families': ['tools'],
|
||||
}
|
||||
}
|
||||
|
||||
abilities = requester._infer_model_abilities(item, 'vision-tool-model')
|
||||
assert 'vision' in abilities
|
||||
assert 'func_call' in abilities
|
||||
|
||||
|
||||
class TestOllamaMakeMessage:
|
||||
"""Tests for _make_msg response parsing."""
|
||||
|
||||
@pytest.fixture
|
||||
def requester(self):
|
||||
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
|
||||
|
||||
return OllamaChatCompletions(MagicMock(), {})
|
||||
|
||||
def _create_ollama_response(self, content, tool_calls=None):
|
||||
"""Helper to create mock ollama response."""
|
||||
import ollama
|
||||
|
||||
mock_response = MagicMock(spec=ollama.ChatResponse)
|
||||
mock_message = MagicMock(spec=ollama.Message)
|
||||
mock_message.content = content
|
||||
mock_message.tool_calls = tool_calls
|
||||
mock_response.message = mock_message
|
||||
|
||||
return mock_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_msg_text_content(self, requester):
|
||||
"""Text content is extracted."""
|
||||
mock_response = self._create_ollama_response('Hello world')
|
||||
|
||||
result = await requester._make_msg(mock_response)
|
||||
|
||||
assert result.content == 'Hello world'
|
||||
assert result.role == 'assistant'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_msg_with_tool_calls(self, requester):
|
||||
"""Tool calls are parsed."""
|
||||
mock_tool_call = MagicMock()
|
||||
mock_tool_call.function = MagicMock()
|
||||
mock_tool_call.function.name = 'get_weather'
|
||||
mock_tool_call.function.arguments = {'location': 'Beijing'}
|
||||
|
||||
mock_response = self._create_ollama_response('', tool_calls=[mock_tool_call])
|
||||
|
||||
result = await requester._make_msg(mock_response)
|
||||
|
||||
assert result.tool_calls is not None
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == 'get_weather'
|
||||
# Arguments should be JSON string
|
||||
assert isinstance(result.tool_calls[0].function.arguments, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_msg_empty_message_raises(self, requester):
|
||||
"""Empty message raises ValueError."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.message = None
|
||||
|
||||
with pytest.raises(ValueError, match='message'):
|
||||
await requester._make_msg(mock_response)
|
||||
|
||||
|
||||
class TestOllamaErrorHandling:
|
||||
"""Tests for error handling branches."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
app = MagicMock()
|
||||
app.tool_mgr = MagicMock()
|
||||
app.tool_mgr.generate_tools_for_openai = AsyncMock(return_value=[])
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def requester_with_mocked_client(self, mock_app):
|
||||
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
|
||||
|
||||
req = OllamaChatCompletions(mock_app, {})
|
||||
req.client = MagicMock()
|
||||
req.client.chat = AsyncMock()
|
||||
|
||||
return req
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self):
|
||||
model = MagicMock()
|
||||
model.model_entity = MagicMock()
|
||||
model.model_entity.name = 'llama2'
|
||||
model.provider = MagicMock()
|
||||
model.provider.token_mgr = MagicMock()
|
||||
model.provider.token_mgr.get_token = MagicMock(return_value='')
|
||||
return model
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
msg = MagicMock()
|
||||
msg.role = 'user'
|
||||
msg.content = 'test'
|
||||
msg.dict = MagicMock(return_value={'role': 'user', 'content': 'test'})
|
||||
return msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_error(self, requester_with_mocked_client, mock_model, mock_message):
|
||||
"""TimeoutError is converted to RequesterError."""
|
||||
requester_with_mocked_client.client.chat = AsyncMock(side_effect=asyncio.TimeoutError())
|
||||
|
||||
with pytest.raises(RequesterError) as exc:
|
||||
await requester_with_mocked_client.invoke_llm(
|
||||
query=None,
|
||||
model=mock_model,
|
||||
messages=[mock_message],
|
||||
)
|
||||
|
||||
assert '超时' in str(exc.value)
|
||||
|
||||
|
||||
class TestOllamaScanModels:
|
||||
"""Tests for scan_models method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app(self):
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def requester(self, mock_app):
|
||||
from langbot.pkg.provider.modelmgr.requesters.ollamachat import OllamaChatCompletions
|
||||
|
||||
req = OllamaChatCompletions(mock_app, {
|
||||
'base_url': 'http://127.0.0.1:11434',
|
||||
'timeout': 120,
|
||||
})
|
||||
return req
|
||||
|
||||
def test_requester_name_constant(self):
|
||||
"""REQUESTER_NAME constant exists."""
|
||||
from langbot.pkg.provider.modelmgr.requesters.ollamachat import REQUESTER_NAME
|
||||
|
||||
assert REQUESTER_NAME == 'ollama-chat'
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner, _StreamAccumulator
|
||||
|
||||
|
||||
class RecordingProvider:
|
||||
@@ -124,6 +124,45 @@ def make_query() -> pipeline_query.Query:
|
||||
)
|
||||
|
||||
|
||||
def test_stream_accumulator_merges_fragmented_tool_call_arguments():
|
||||
accumulator = _StreamAccumulator(msg_sequence=1)
|
||||
|
||||
assert (
|
||||
accumulator.add(
|
||||
provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(name='exec', arguments='{"command":'),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
emitted = accumulator.add(
|
||||
provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
provider_message.ToolCall(
|
||||
id='call-1',
|
||||
type='function',
|
||||
function=provider_message.FunctionCall(name='exec', arguments='"pwd"}'),
|
||||
)
|
||||
],
|
||||
is_final=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert emitted is not None
|
||||
final_msg = accumulator.final_message()
|
||||
assert final_msg.tool_calls[0].function.name == 'exec'
|
||||
assert final_msg.tool_calls[0].function.arguments == '{"command":"pwd"}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localagent_uses_exec_for_exact_calculation():
|
||||
provider = RecordingProvider()
|
||||
|
||||
@@ -494,6 +494,7 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
|
||||
'api_keys': ['temp-key'],
|
||||
},
|
||||
'abilities': ['func_call'],
|
||||
'context_length': 128000,
|
||||
'extra_args': {'temperature': 0.5},
|
||||
}
|
||||
|
||||
@@ -501,6 +502,9 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
|
||||
|
||||
assert runtime_model.model_entity.uuid == 'temp-model-uuid'
|
||||
assert runtime_model.model_entity.name == 'TempModel'
|
||||
assert runtime_model.model_entity.context_length == 128000
|
||||
assert runtime_model.model_entity.extra_args == {'temperature': 0.5}
|
||||
assert 'context_length' not in runtime_model.model_entity.extra_args
|
||||
assert runtime_model.provider.provider_entity.uuid == 'temp-provider-uuid'
|
||||
assert runtime_model.provider.token_mgr.tokens == ['temp-key']
|
||||
|
||||
@@ -785,4 +789,4 @@ def test_provider_not_found_error_str():
|
||||
error = provider_errors.ProviderNotFoundError('test-provider')
|
||||
|
||||
assert str(error) == 'Provider test-provider not found'
|
||||
assert error.provider_name == 'test-provider'
|
||||
assert error.provider_name == 'test-provider'
|
||||
|
||||
@@ -16,8 +16,6 @@ from langbot.pkg.entity.persistence import model as persistence_model
|
||||
from langbot.pkg.pipeline.preproc.preproc import PreProcessor
|
||||
from langbot.pkg.provider.modelmgr import requester
|
||||
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
from langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl import ModelScopeChatCompletions
|
||||
from langbot.pkg.provider.modelmgr.token import TokenManager
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||
|
||||
@@ -90,74 +88,6 @@ def test_token_manager_next_token_ignores_empty_token_list():
|
||||
assert token_mgr.using_token_index == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_requester_initialize_uses_placeholder_api_key(monkeypatch):
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_client(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.chatcmpl.openai.AsyncClient', fake_client)
|
||||
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.chatcmpl.httpx.AsyncClient', fake_client)
|
||||
|
||||
requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={})
|
||||
await requester_inst.initialize()
|
||||
|
||||
assert captured_kwargs['api_key'] == OpenAIChatCompletions.init_api_key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modelscope_requester_initialize_uses_placeholder_api_key(monkeypatch):
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_client(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.openai.AsyncClient', fake_client)
|
||||
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.httpx.AsyncClient', fake_client)
|
||||
|
||||
requester_inst = ModelScopeChatCompletions(ap=SimpleNamespace(), config={})
|
||||
await requester_inst.initialize()
|
||||
|
||||
assert captured_kwargs['api_key'] == ModelScopeChatCompletions.init_api_key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_embedding_call_overrides_placeholder_api_key():
|
||||
captured_request = {}
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_request['api_key'] = fake_client.api_key
|
||||
captured_request['kwargs'] = kwargs
|
||||
return SimpleNamespace(
|
||||
data=[SimpleNamespace(embedding=[0.1, 0.2])],
|
||||
usage=SimpleNamespace(prompt_tokens=3, total_tokens=3),
|
||||
)
|
||||
|
||||
fake_client = SimpleNamespace(
|
||||
api_key=OpenAIChatCompletions.init_api_key,
|
||||
embeddings=SimpleNamespace(create=fake_create),
|
||||
)
|
||||
|
||||
requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={})
|
||||
requester_inst.client = fake_client
|
||||
|
||||
embeddings, usage_info = await requester_inst.invoke_embedding(
|
||||
model=requester.RuntimeEmbeddingModel(
|
||||
model_entity=SimpleNamespace(name='text-embedding-3-small', extra_args={}),
|
||||
provider=SimpleNamespace(token_mgr=TokenManager('provider-uuid', [' runtime-key ', '', 'runtime-key'])),
|
||||
),
|
||||
input_text=['hello'],
|
||||
)
|
||||
|
||||
assert captured_request['api_key'] == 'runtime-key'
|
||||
assert captured_request['kwargs']['model'] == 'text-embedding-3-small'
|
||||
assert embeddings == [[0.1, 0.2]]
|
||||
assert usage_info == {'prompt_tokens': 3, 'total_tokens': 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline():
|
||||
from langbot.pkg.api.http.service.model import LLMModelsService
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for ToolManager.
|
||||
|
||||
Tests cover:
|
||||
- Tool schema generation for OpenAI and Anthropic
|
||||
- Tool schema generation for OpenAI/LiteLLM
|
||||
- Tool execution dispatch
|
||||
"""
|
||||
|
||||
@@ -109,28 +109,6 @@ class TestToolManagerSchemaGeneration:
|
||||
assert tool2['type'] == 'function'
|
||||
assert tool2['function']['name'] == 'calculate'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_tools_for_anthropic(self, mock_app, sample_tools):
|
||||
"""Test that generate_tools_for_anthropic produces correct schema."""
|
||||
toolmgr = get_toolmgr_module()
|
||||
|
||||
manager = toolmgr.ToolManager(mock_app)
|
||||
result = await manager.generate_tools_for_anthropic(sample_tools)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
# Verify first tool schema (Anthropic format)
|
||||
tool1 = result[0]
|
||||
assert tool1['name'] == 'get_weather'
|
||||
assert tool1['description'] == 'Get current weather for a location'
|
||||
assert 'input_schema' in tool1
|
||||
assert tool1['input_schema']['type'] == 'object'
|
||||
|
||||
# Verify second tool schema
|
||||
tool2 = result[1]
|
||||
assert tool2['name'] == 'calculate'
|
||||
assert 'input_schema' in tool2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_tools_empty_list(self, mock_app):
|
||||
"""Test that generating tools from empty list returns empty list."""
|
||||
@@ -141,9 +119,6 @@ class TestToolManagerSchemaGeneration:
|
||||
openai_result = await manager.generate_tools_for_openai([])
|
||||
assert openai_result == []
|
||||
|
||||
anthropic_result = await manager.generate_tools_for_anthropic([])
|
||||
assert anthropic_result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_schema_fields_complete(self, mock_app, sample_tools):
|
||||
"""Test that OpenAI schema includes all required fields."""
|
||||
@@ -161,20 +136,6 @@ class TestToolManagerSchemaGeneration:
|
||||
assert 'description' in func
|
||||
assert 'parameters' in func
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_schema_fields_complete(self, mock_app, sample_tools):
|
||||
"""Test that Anthropic schema includes all required fields."""
|
||||
toolmgr = get_toolmgr_module()
|
||||
|
||||
manager = toolmgr.ToolManager(mock_app)
|
||||
result = await manager.generate_tools_for_anthropic(sample_tools)
|
||||
|
||||
for tool_schema in result:
|
||||
assert 'name' in tool_schema
|
||||
assert 'description' in tool_schema
|
||||
assert 'input_schema' in tool_schema
|
||||
|
||||
|
||||
class TestToolManagerExecuteFuncCall:
|
||||
"""Tests for execute_func_call method."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user