mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-08 14:56:03 +00:00
feat(models): persist context metadata
This commit is contained in:
@@ -31,6 +31,7 @@ class LLMModel(Base):
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
|
||||
context_length = sqlalchemy.Column(sqlalchemy.Integer, nullable=True)
|
||||
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add llm model context length
|
||||
|
||||
Revision ID: 0004_add_llm_model_context_length
|
||||
Revises: 0003_add_rerank_models
|
||||
Create Date: 2026-06-07
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '0004_add_llm_model_context_length'
|
||||
down_revision = '0003_add_rerank_models'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = {column['name'] for column in inspector.get_columns('llm_models')}
|
||||
if 'context_length' not in columns:
|
||||
op.add_column('llm_models', sa.Column('context_length', sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
columns = {column['name'] for column in inspector.get_columns('llm_models')}
|
||||
if 'context_length' in columns:
|
||||
op.drop_column('llm_models', 'context_length')
|
||||
@@ -0,0 +1,42 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(26)
|
||||
class DBMigrateLLMModelContextLength(migration.DBMigration):
|
||||
"""Add context_length column to LLM models"""
|
||||
|
||||
async def upgrade(self):
|
||||
columns = await self._get_columns('llm_models')
|
||||
if 'context_length' not in columns:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN context_length INTEGER')
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
columns = await self._get_columns('llm_models')
|
||||
if 'context_length' not in columns:
|
||||
return
|
||||
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE llm_models DROP COLUMN IF EXISTS context_length')
|
||||
)
|
||||
else:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text('ALTER TABLE llm_models DROP COLUMN context_length')
|
||||
)
|
||||
|
||||
async def _get_columns(self, table_name: str) -> set[str]:
|
||||
if self.ap.persistence_mgr.db.name == 'postgresql':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("""
|
||||
SELECT column_name FROM information_schema.columns
|
||||
WHERE table_name = :table_name
|
||||
"""),
|
||||
{'table_name': table_name},
|
||||
)
|
||||
return {row[0] for row in result.fetchall()}
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name})'))
|
||||
return {row[1] for row in result.fetchall()}
|
||||
@@ -266,6 +266,7 @@ class ModelManager:
|
||||
name=model_info.get('name', ''),
|
||||
provider_uuid='',
|
||||
abilities=model_info.get('abilities', []),
|
||||
context_length=model_info.get('context_length'),
|
||||
extra_args=model_info.get('extra_args', {}),
|
||||
),
|
||||
provider=runtime_provider,
|
||||
@@ -460,6 +461,7 @@ class ModelManager:
|
||||
name=model_info.get('name', ''),
|
||||
provider_uuid=model_info.get('provider_uuid', ''),
|
||||
abilities=model_info.get('abilities', []),
|
||||
context_length=model_info.get('context_length'),
|
||||
extra_args=model_info.get('extra_args', {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -59,8 +59,15 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
litellm_model_name = self._build_litellm_model_name(model_name)
|
||||
if litellm_model_name != model_name:
|
||||
candidates.append((litellm_model_name, None))
|
||||
for metadata_provider in self._metadata_provider_candidates(model_name):
|
||||
candidates.append((f'{metadata_provider}/{model_name}', None))
|
||||
|
||||
tried_candidates: set[tuple[str, str | None]] = set()
|
||||
for candidate_model, candidate_provider in candidates:
|
||||
candidate_key = (candidate_model, candidate_provider)
|
||||
if candidate_key in tried_candidates:
|
||||
continue
|
||||
tried_candidates.add(candidate_key)
|
||||
try:
|
||||
if bool(helper(model=candidate_model, custom_llm_provider=candidate_provider)):
|
||||
return True
|
||||
@@ -68,24 +75,80 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
continue
|
||||
return False
|
||||
|
||||
def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None:
|
||||
if not model_payload:
|
||||
return None
|
||||
|
||||
for field_name in ('context_length', 'context_window', 'max_context_length'):
|
||||
value = model_payload.get(field_name)
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int) and value > 0:
|
||||
return value
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
parsed_value = int(value)
|
||||
if parsed_value > 0:
|
||||
return parsed_value
|
||||
return None
|
||||
|
||||
def _metadata_provider_candidates(self, model_name: str) -> list[str]:
|
||||
normalized_model_name = (model_name or '').lower()
|
||||
candidates = []
|
||||
if normalized_model_name.startswith(('moonshot-', 'kimi-')):
|
||||
candidates.append('moonshot')
|
||||
if normalized_model_name.startswith('deepseek-'):
|
||||
candidates.append('deepseek')
|
||||
|
||||
base_url = self.requester_cfg.get('base_url', '').lower()
|
||||
if 'moonshot' in base_url:
|
||||
candidates.append('moonshot')
|
||||
if 'deepseek' in base_url:
|
||||
candidates.append('deepseek')
|
||||
|
||||
deduped_candidates = []
|
||||
for candidate in candidates:
|
||||
if candidate not in deduped_candidates:
|
||||
deduped_candidates.append(candidate)
|
||||
return deduped_candidates
|
||||
|
||||
def _known_context_length_fallback(self, model_name: str) -> int | None:
|
||||
normalized_model_name = (model_name or '').lower()
|
||||
if normalized_model_name.startswith('deepseek-v4-'):
|
||||
return 1_000_000
|
||||
if normalized_model_name.startswith(('kimi-k2.5', 'kimi-k2.6')):
|
||||
return 256 * 1024
|
||||
if normalized_model_name.startswith('moonshot-v1-8k'):
|
||||
return 8 * 1024
|
||||
if normalized_model_name.startswith('moonshot-v1-32k'):
|
||||
return 32 * 1024
|
||||
if normalized_model_name.startswith('moonshot-v1-128k') or normalized_model_name == 'moonshot-v1-auto':
|
||||
return 128 * 1024
|
||||
return None
|
||||
|
||||
def _safe_context_length(self, model_name: str) -> int | None:
|
||||
helper = getattr(litellm, 'get_max_tokens', None)
|
||||
if not callable(helper):
|
||||
return None
|
||||
return self._known_context_length_fallback(model_name)
|
||||
|
||||
candidates = [model_name]
|
||||
litellm_model_name = self._build_litellm_model_name(model_name)
|
||||
if litellm_model_name != model_name:
|
||||
candidates.append(litellm_model_name)
|
||||
for provider in self._metadata_provider_candidates(model_name):
|
||||
candidates.append(f'{provider}/{model_name}')
|
||||
|
||||
tried_candidates = []
|
||||
for candidate in candidates:
|
||||
if candidate in tried_candidates:
|
||||
continue
|
||||
tried_candidates.append(candidate)
|
||||
try:
|
||||
max_tokens = helper(candidate)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(max_tokens, int) and max_tokens > 0:
|
||||
return max_tokens
|
||||
return None
|
||||
return self._known_context_length_fallback(model_name)
|
||||
|
||||
def _supports_function_calling(self, model_name: str) -> bool:
|
||||
return self._safe_litellm_bool_helper('supports_function_calling', model_name)
|
||||
@@ -101,7 +164,11 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
return 'embedding'
|
||||
return 'llm'
|
||||
|
||||
def _enrich_scanned_model(self, model_id: str) -> dict[str, typing.Any]:
|
||||
def _enrich_scanned_model(
|
||||
self,
|
||||
model_id: str,
|
||||
model_payload: dict[str, typing.Any] | None = None,
|
||||
) -> dict[str, typing.Any]:
|
||||
model_type = self._infer_model_type(model_id)
|
||||
scanned_model: dict[str, typing.Any] = {
|
||||
'id': model_id,
|
||||
@@ -113,11 +180,17 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
abilities = []
|
||||
if self._supports_function_calling(model_id):
|
||||
abilities.append('func_call')
|
||||
if self._supports_vision(model_id):
|
||||
supports_provider_reported_vision = bool(
|
||||
model_payload
|
||||
and (model_payload.get('supports_image_in') is True or model_payload.get('supports_vision') is True)
|
||||
)
|
||||
if supports_provider_reported_vision or self._supports_vision(model_id):
|
||||
abilities.append('vision')
|
||||
scanned_model['abilities'] = abilities
|
||||
|
||||
context_length = self._safe_context_length(model_id)
|
||||
context_length = self._context_length_from_scan_payload(model_payload)
|
||||
if context_length is None:
|
||||
context_length = self._safe_context_length(model_id)
|
||||
if context_length is not None:
|
||||
scanned_model['context_length'] = context_length
|
||||
|
||||
@@ -557,7 +630,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
models.append(self._enrich_scanned_model(model_id))
|
||||
models.append(self._enrich_scanned_model(model_id, item))
|
||||
|
||||
models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower()))
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import langbot
|
||||
|
||||
semantic_version = f'v{langbot.__version__}'
|
||||
|
||||
required_database_version = 25
|
||||
required_database_version = 26
|
||||
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
|
||||
|
||||
debug_mode = False
|
||||
|
||||
@@ -35,6 +35,7 @@ def _create_mock_llm_model(
|
||||
name: str = 'Test LLM',
|
||||
provider_uuid: str = 'provider-uuid',
|
||||
abilities: list = None,
|
||||
context_length: int | None = None,
|
||||
extra_args: dict = None,
|
||||
) -> Mock:
|
||||
"""Helper to create mock LLMModel entity."""
|
||||
@@ -43,6 +44,7 @@ def _create_mock_llm_model(
|
||||
model.name = name
|
||||
model.provider_uuid = provider_uuid
|
||||
model.abilities = abilities or []
|
||||
model.context_length = context_length
|
||||
model.extra_args = extra_args or {}
|
||||
return model
|
||||
|
||||
@@ -142,10 +144,12 @@ class TestRuntimeModelData:
|
||||
'name': 'Model',
|
||||
'provider_uuid': 'provider',
|
||||
'abilities': ['vision'],
|
||||
'context_length': 128000,
|
||||
'extra_args': {'temp': 0.7},
|
||||
}
|
||||
result = _runtime_model_data('uuid', update_payload)
|
||||
assert result['abilities'] == ['vision']
|
||||
assert result['context_length'] == 128000
|
||||
assert result['extra_args'] == {'temp': 0.7}
|
||||
|
||||
|
||||
@@ -188,7 +192,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model()
|
||||
model = _create_mock_llm_model(context_length=128000)
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([model])
|
||||
@@ -206,6 +210,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
|
||||
'context_length': getattr(entity, 'context_length', None),
|
||||
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
|
||||
}
|
||||
)
|
||||
@@ -218,6 +223,7 @@ class TestLLMModelsServiceGetLLMModels:
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0]['name'] == 'Test LLM'
|
||||
assert result[0]['context_length'] == 128000
|
||||
|
||||
async def test_get_llm_models_hide_secret_keys(self):
|
||||
"""Hides secret API keys when include_secret=False."""
|
||||
@@ -265,7 +271,7 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
|
||||
model = _create_mock_llm_model(model_uuid='found-uuid')
|
||||
model = _create_mock_llm_model(model_uuid='found-uuid', context_length=128000)
|
||||
provider = _create_mock_provider()
|
||||
|
||||
mock_model_result = _create_mock_result([], first_item=model)
|
||||
@@ -279,11 +285,12 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
|
||||
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
|
||||
ap.persistence_mgr.serialize_model = Mock(
|
||||
return_value={
|
||||
'uuid': 'found-uuid',
|
||||
'name': 'Test LLM',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']},
|
||||
side_effect=lambda model_cls, entity: {
|
||||
'uuid': entity.uuid,
|
||||
'name': entity.name,
|
||||
'provider_uuid': getattr(entity, 'provider_uuid', None),
|
||||
'context_length': getattr(entity, 'context_length', None),
|
||||
'api_keys': getattr(entity, 'api_keys', None),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -295,6 +302,7 @@ class TestLLMModelsServiceGetLLMModel:
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result['uuid'] == 'found-uuid'
|
||||
assert result['context_length'] == 128000
|
||||
|
||||
async def test_get_llm_model_not_found(self):
|
||||
"""Returns None when model not found."""
|
||||
@@ -402,6 +410,39 @@ class TestLLMModelsServiceCreateLLMModel:
|
||||
# Verify
|
||||
assert model_uuid == 'preserved-uuid'
|
||||
|
||||
async def test_create_llm_model_persists_context_length_as_column(self):
|
||||
"""Creates LLM model with context_length outside extra_args."""
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace()
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
ap.pipeline_service = SimpleNamespace(update_pipeline=AsyncMock())
|
||||
|
||||
mock_result = _create_mock_result([])
|
||||
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
await service.create_llm_model(
|
||||
{
|
||||
'uuid': 'model-with-context',
|
||||
'name': 'Context Model',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': ['func_call'],
|
||||
'context_length': 128000,
|
||||
'extra_args': {'temperature': 0.2},
|
||||
},
|
||||
preserve_uuid=True,
|
||||
auto_set_to_default_pipeline=False,
|
||||
)
|
||||
|
||||
runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0]
|
||||
assert runtime_entity.context_length == 128000
|
||||
assert runtime_entity.extra_args == {'temperature': 0.2}
|
||||
assert 'context_length' not in runtime_entity.extra_args
|
||||
|
||||
async def test_create_llm_model_provider_not_found_raises_error(self):
|
||||
"""Raises Exception when provider not found in runtime."""
|
||||
# Setup
|
||||
@@ -512,6 +553,35 @@ class TestLLMModelsServiceUpdateLLMModel:
|
||||
'provider_uuid': 'nonexistent-provider',
|
||||
})
|
||||
|
||||
async def test_update_llm_model_reloads_context_length_as_column(self):
|
||||
"""Updates runtime model with context_length outside extra_args."""
|
||||
ap = SimpleNamespace()
|
||||
ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
|
||||
ap.model_mgr = SimpleNamespace()
|
||||
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
|
||||
ap.model_mgr.llm_models = []
|
||||
ap.model_mgr.remove_llm_model = AsyncMock()
|
||||
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
|
||||
|
||||
service = LLMModelsService(ap)
|
||||
|
||||
await service.update_llm_model(
|
||||
'existing-uuid',
|
||||
{
|
||||
'name': 'Updated Name',
|
||||
'provider_uuid': 'provider-uuid',
|
||||
'abilities': ['vision'],
|
||||
'context_length': 64000,
|
||||
'extra_args': {'temperature': 0.4},
|
||||
},
|
||||
)
|
||||
|
||||
runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0]
|
||||
assert runtime_entity.uuid == 'existing-uuid'
|
||||
assert runtime_entity.context_length == 64000
|
||||
assert runtime_entity.extra_args == {'temperature': 0.4}
|
||||
assert 'context_length' not in runtime_entity.extra_args
|
||||
|
||||
|
||||
class TestLLMModelsServiceDeleteLLMModel:
|
||||
"""Tests for LLMModelsService.delete_llm_model method."""
|
||||
@@ -961,4 +1031,4 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
|
||||
result = await service.get_rerank_models_by_provider('provider-uuid')
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert len(result) == 2
|
||||
|
||||
@@ -896,6 +896,121 @@ class TestScanModels:
|
||||
assert by_id['text-embedding-3-small']['type'] == 'embedding'
|
||||
assert by_id['bge-reranker-v2']['type'] == 'rerank'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_models_prefers_context_length_from_provider_payload(self):
|
||||
"""Provider-supplied context_length is preserved before LiteLLM metadata fallback."""
|
||||
requester = litellmchat.LiteLLMRequester(
|
||||
ap=Mock(),
|
||||
config={
|
||||
'base_url': 'https://api.moonshot.cn/v1',
|
||||
'timeout': 60,
|
||||
},
|
||||
)
|
||||
requester._supports_function_calling = Mock(return_value=False)
|
||||
requester._supports_vision = Mock(return_value=False)
|
||||
requester._safe_context_length = Mock(return_value=None)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json = Mock(
|
||||
return_value={
|
||||
'data': [
|
||||
{'id': 'moonshot-v1-128k', 'context_length': 131072},
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await requester.scan_models(api_key='test-key')
|
||||
|
||||
assert result['models'][0]['context_length'] == 131072
|
||||
requester._safe_context_length.assert_not_called()
|
||||
|
||||
def test_safe_context_length_tries_moonshot_metadata_alias(self):
|
||||
"""OpenAI-compatible Moonshot endpoints still use Moonshot metadata for context windows."""
|
||||
requester = litellmchat.LiteLLMRequester(
|
||||
ap=Mock(),
|
||||
config={
|
||||
'base_url': 'https://api.moonshot.cn/v1',
|
||||
'custom_llm_provider': 'openai',
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(litellmchat.litellm, 'get_max_tokens') as mock_get_max_tokens:
|
||||
mock_get_max_tokens.side_effect = lambda model: 131072 if model == 'moonshot/moonshot-v1-128k' else None
|
||||
|
||||
assert requester._safe_context_length('moonshot-v1-128k') == 131072
|
||||
|
||||
def test_litellm_bool_helper_tries_moonshot_metadata_alias(self):
|
||||
"""OpenAI-compatible Moonshot endpoints still use Moonshot metadata for abilities."""
|
||||
requester = litellmchat.LiteLLMRequester(
|
||||
ap=Mock(),
|
||||
config={
|
||||
'base_url': 'https://api.moonshot.cn/v1',
|
||||
'custom_llm_provider': 'openai',
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling:
|
||||
mock_supports_function_calling.side_effect = (
|
||||
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6'
|
||||
and custom_llm_provider is None
|
||||
)
|
||||
|
||||
assert requester._supports_function_calling('kimi-k2.6') is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_models_uses_provider_payload_for_vision_ability(self):
|
||||
"""Provider-supplied vision support is used when scanning models."""
|
||||
requester = litellmchat.LiteLLMRequester(
|
||||
ap=Mock(),
|
||||
config={
|
||||
'base_url': 'https://api.moonshot.cn/v1',
|
||||
'timeout': 60,
|
||||
},
|
||||
)
|
||||
requester._supports_function_calling = Mock(return_value=False)
|
||||
requester._supports_vision = Mock(return_value=False)
|
||||
requester._safe_context_length = Mock(return_value=None)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json = Mock(
|
||||
return_value={
|
||||
'data': [
|
||||
{
|
||||
'id': 'moonshot-v1-128k-vision-preview',
|
||||
'supports_image_in': True,
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await requester.scan_models(api_key='test-key')
|
||||
|
||||
assert result['models'][0]['abilities'] == ['vision']
|
||||
|
||||
def test_safe_context_length_falls_back_for_deepseek_v4_models(self):
|
||||
"""DeepSeek V4 API ids have a known 1M context even before LiteLLM maps them."""
|
||||
requester = litellmchat.LiteLLMRequester(
|
||||
ap=Mock(),
|
||||
config={
|
||||
'base_url': 'https://api.deepseek.com',
|
||||
'custom_llm_provider': 'deepseek',
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(litellmchat.litellm, 'get_max_tokens', side_effect=Exception('not mapped')):
|
||||
assert requester._safe_context_length('deepseek-v4-pro') == 1_000_000
|
||||
assert requester._safe_context_length('deepseek-v4-flash') == 1_000_000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_models_no_base_url(self):
|
||||
"""Test scan_models without base_url raises error"""
|
||||
|
||||
@@ -494,6 +494,7 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
|
||||
'api_keys': ['temp-key'],
|
||||
},
|
||||
'abilities': ['func_call'],
|
||||
'context_length': 128000,
|
||||
'extra_args': {'temperature': 0.5},
|
||||
}
|
||||
|
||||
@@ -501,6 +502,9 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
|
||||
|
||||
assert runtime_model.model_entity.uuid == 'temp-model-uuid'
|
||||
assert runtime_model.model_entity.name == 'TempModel'
|
||||
assert runtime_model.model_entity.context_length == 128000
|
||||
assert runtime_model.model_entity.extra_args == {'temperature': 0.5}
|
||||
assert 'context_length' not in runtime_model.model_entity.extra_args
|
||||
assert runtime_model.provider.provider_entity.uuid == 'temp-provider-uuid'
|
||||
assert runtime_model.provider.token_mgr.tokens == ['temp-key']
|
||||
|
||||
@@ -785,4 +789,4 @@ def test_provider_not_found_error_str():
|
||||
error = provider_errors.ProviderNotFoundError('test-provider')
|
||||
|
||||
assert str(error) == 'Provider test-provider not found'
|
||||
assert error.provider_name == 'test-provider'
|
||||
assert error.provider_name == 'test-provider'
|
||||
|
||||
@@ -64,6 +64,17 @@ function convertExtraArgsToObject(
|
||||
return obj;
|
||||
}
|
||||
|
||||
function parseContextLength(
|
||||
value: number | null | undefined,
|
||||
invalidMessage: string,
|
||||
): number | null {
|
||||
if (value === undefined || value === null) return null;
|
||||
if (!Number.isInteger(value) || value <= 0) {
|
||||
throw new Error(invalidMessage);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
export default function ModelsDialog({
|
||||
open,
|
||||
onOpenChange,
|
||||
@@ -254,6 +265,7 @@ export default function ModelsDialog({
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) {
|
||||
if (!name.trim()) {
|
||||
toast.error(t('models.modelNameRequired'));
|
||||
@@ -268,6 +280,10 @@ export default function ModelsDialog({
|
||||
name,
|
||||
provider_uuid: providerUuid,
|
||||
abilities,
|
||||
context_length: parseContextLength(
|
||||
contextLength,
|
||||
t('models.contextLengthInvalid'),
|
||||
),
|
||||
extra_args: extraArgsObj,
|
||||
} as never);
|
||||
} else if (modelType === 'embedding') {
|
||||
@@ -325,6 +341,7 @@ export default function ModelsDialog({
|
||||
name: item.model.name,
|
||||
provider_uuid: providerUuid,
|
||||
abilities: item.abilities,
|
||||
context_length: item.model.context_length ?? null,
|
||||
extra_args: {},
|
||||
} as never);
|
||||
} else if (effectiveType === 'embedding') {
|
||||
@@ -361,6 +378,7 @@ export default function ModelsDialog({
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) {
|
||||
if (!name.trim()) {
|
||||
toast.error(t('models.modelNameRequired'));
|
||||
@@ -375,6 +393,10 @@ export default function ModelsDialog({
|
||||
name,
|
||||
provider_uuid: providerUuid,
|
||||
abilities,
|
||||
context_length: parseContextLength(
|
||||
contextLength,
|
||||
t('models.contextLengthInvalid'),
|
||||
),
|
||||
extra_args: extraArgsObj,
|
||||
} as never);
|
||||
} else if (modelType === 'embedding') {
|
||||
@@ -509,8 +531,15 @@ export default function ModelsDialog({
|
||||
onSpaceLogin={handleSpaceLogin}
|
||||
onOpenAddModel={() => setAddModelPopoverOpen(provider.uuid)}
|
||||
onCloseAddModel={() => setAddModelPopoverOpen(null)}
|
||||
onAddModel={(modelType, name, abilities, extraArgs) =>
|
||||
handleAddModel(provider.uuid, modelType, name, abilities, extraArgs)
|
||||
onAddModel={(modelType, name, abilities, extraArgs, contextLength) =>
|
||||
handleAddModel(
|
||||
provider.uuid,
|
||||
modelType,
|
||||
name,
|
||||
abilities,
|
||||
extraArgs,
|
||||
contextLength,
|
||||
)
|
||||
}
|
||||
onScanModels={(modelType) => handleScanModels(provider.uuid, modelType)}
|
||||
onAddScannedModels={(modelType, models) =>
|
||||
@@ -518,7 +547,14 @@ export default function ModelsDialog({
|
||||
}
|
||||
onOpenEditModel={(modelId) => setEditModelPopoverOpen(modelId)}
|
||||
onCloseEditModel={() => setEditModelPopoverOpen(null)}
|
||||
onUpdateModel={(modelId, modelType, name, abilities, extraArgs) =>
|
||||
onUpdateModel={(
|
||||
modelId,
|
||||
modelType,
|
||||
name,
|
||||
abilities,
|
||||
extraArgs,
|
||||
contextLength,
|
||||
) =>
|
||||
handleUpdateModel(
|
||||
provider.uuid,
|
||||
modelId,
|
||||
@@ -526,6 +562,7 @@ export default function ModelsDialog({
|
||||
name,
|
||||
abilities,
|
||||
extraArgs,
|
||||
contextLength,
|
||||
)
|
||||
}
|
||||
onOpenDeleteConfirm={(modelId) => setDeleteConfirmOpen(modelId)}
|
||||
|
||||
@@ -41,6 +41,7 @@ interface AddModelPopoverProps {
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) => Promise<void>;
|
||||
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
|
||||
onAddScannedModels: (
|
||||
@@ -81,6 +82,7 @@ export default function AddModelPopover({
|
||||
const [mode, setMode] = useState<'manual' | 'scan'>('manual');
|
||||
const [name, setName] = useState('');
|
||||
const [abilities, setAbilities] = useState<string[]>([]);
|
||||
const [contextLength, setContextLength] = useState('');
|
||||
const [extraArgs, setExtraArgs] = useState<ExtraArg[]>([]);
|
||||
const [scanLoading, setScanLoading] = useState(false);
|
||||
const [scannedModels, setScannedModels] = useState<ScannedProviderModel[]>(
|
||||
@@ -98,6 +100,7 @@ export default function AddModelPopover({
|
||||
setMode(initialMode);
|
||||
setName('');
|
||||
setAbilities([]);
|
||||
setContextLength('');
|
||||
setExtraArgs([]);
|
||||
setScanLoading(false);
|
||||
setScannedModels([]);
|
||||
@@ -119,7 +122,11 @@ export default function AddModelPopover({
|
||||
}, [tab, mode]);
|
||||
|
||||
const handleAdd = async () => {
|
||||
await onAddModel(tab, name, abilities, extraArgs);
|
||||
const parsedContextLength =
|
||||
tab === 'llm' && contextLength.trim()
|
||||
? Number(contextLength.trim())
|
||||
: null;
|
||||
await onAddModel(tab, name, abilities, extraArgs, parsedContextLength);
|
||||
};
|
||||
|
||||
const handleTest = async () => {
|
||||
@@ -318,6 +325,24 @@ export default function AddModelPopover({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{tab === 'llm' && (
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="add-context-length">
|
||||
{t('models.contextLength')}
|
||||
</Label>
|
||||
<Input
|
||||
id="add-context-length"
|
||||
type="number"
|
||||
min={1}
|
||||
step={1}
|
||||
inputMode="numeric"
|
||||
placeholder={t('models.contextLengthPlaceholder')}
|
||||
value={contextLength}
|
||||
onChange={(e) => setContextLength(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<ExtraArgsEditor
|
||||
args={extraArgs}
|
||||
onChange={setExtraArgs}
|
||||
|
||||
@@ -31,6 +31,7 @@ interface ModelItemProps {
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) => Promise<void>;
|
||||
onTestModel: (
|
||||
name: string,
|
||||
@@ -92,6 +93,11 @@ export default function ModelItem({
|
||||
const [editAbilities, setEditAbilities] = useState<string[]>(
|
||||
modelType === 'llm' ? (model as LLMModel).abilities || [] : [],
|
||||
);
|
||||
const [editContextLength, setEditContextLength] = useState(
|
||||
modelType === 'llm' && (model as LLMModel).context_length
|
||||
? String((model as LLMModel).context_length)
|
||||
: '',
|
||||
);
|
||||
const [editExtraArgs, setEditExtraArgs] = useState<ExtraArg[]>(
|
||||
convertExtraArgsToArray(model.extra_args),
|
||||
);
|
||||
@@ -106,13 +112,27 @@ export default function ModelItem({
|
||||
setEditAbilities(
|
||||
modelType === 'llm' ? (model as LLMModel).abilities || [] : [],
|
||||
);
|
||||
setEditContextLength(
|
||||
modelType === 'llm' && (model as LLMModel).context_length
|
||||
? String((model as LLMModel).context_length)
|
||||
: '',
|
||||
);
|
||||
setEditExtraArgs(convertExtraArgsToArray(model.extra_args));
|
||||
onResetTestResult();
|
||||
}
|
||||
}, [isEditOpen]);
|
||||
|
||||
const handleSave = async () => {
|
||||
await onUpdateModel(editName, editAbilities, editExtraArgs);
|
||||
const parsedContextLength =
|
||||
modelType === 'llm' && editContextLength.trim()
|
||||
? Number(editContextLength.trim())
|
||||
: null;
|
||||
await onUpdateModel(
|
||||
editName,
|
||||
editAbilities,
|
||||
editExtraArgs,
|
||||
parsedContextLength,
|
||||
);
|
||||
};
|
||||
|
||||
const handleTest = async () => {
|
||||
@@ -287,6 +307,25 @@ export default function ModelItem({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{modelType === 'llm' && (
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor={`edit-context-length-${model.uuid}`}>
|
||||
{t('models.contextLength')}
|
||||
</Label>
|
||||
<Input
|
||||
id={`edit-context-length-${model.uuid}`}
|
||||
type="number"
|
||||
min={1}
|
||||
step={1}
|
||||
inputMode="numeric"
|
||||
placeholder={t('models.contextLengthPlaceholder')}
|
||||
value={editContextLength}
|
||||
disabled={isLangBotModels}
|
||||
onChange={(e) => setEditContextLength(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<ExtraArgsEditor
|
||||
args={editExtraArgs}
|
||||
onChange={setEditExtraArgs}
|
||||
|
||||
@@ -60,6 +60,7 @@ interface ProviderCardProps {
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) => Promise<void>;
|
||||
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
|
||||
onAddScannedModels: (
|
||||
@@ -74,6 +75,7 @@ interface ProviderCardProps {
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) => Promise<void>;
|
||||
onOpenDeleteConfirm: (modelId: string) => void;
|
||||
onCloseDeleteConfirm: () => void;
|
||||
@@ -405,13 +407,19 @@ export default function ProviderCard({
|
||||
onOpenDeleteConfirm={onOpenDeleteConfirm}
|
||||
onCloseDeleteConfirm={onCloseDeleteConfirm}
|
||||
onDeleteModel={() => onDeleteModel(model.uuid, 'llm')}
|
||||
onUpdateModel={(name, abilities, extraArgs) =>
|
||||
onUpdateModel={(
|
||||
name,
|
||||
abilities,
|
||||
extraArgs,
|
||||
contextLength,
|
||||
) =>
|
||||
onUpdateModel(
|
||||
model.uuid,
|
||||
'llm',
|
||||
name,
|
||||
abilities,
|
||||
extraArgs,
|
||||
contextLength,
|
||||
)
|
||||
}
|
||||
onTestModel={(name, abilities, extraArgs) =>
|
||||
|
||||
@@ -53,6 +53,7 @@ export interface ModelItemProps {
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) => Promise<void>;
|
||||
onTest: (
|
||||
name: string,
|
||||
@@ -89,6 +90,7 @@ export interface ProviderCardProps {
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) => Promise<void>;
|
||||
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
|
||||
onAddScannedModels: (
|
||||
@@ -103,6 +105,7 @@ export interface ProviderCardProps {
|
||||
name: string,
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
contextLength?: number | null,
|
||||
) => Promise<void>;
|
||||
onOpenDeleteConfirm: (modelId: string) => void;
|
||||
onCloseDeleteConfirm: () => void;
|
||||
|
||||
@@ -96,6 +96,7 @@ export interface LLMModel {
|
||||
provider_uuid: string;
|
||||
provider?: ModelProvider;
|
||||
abilities?: string[];
|
||||
context_length?: number | null;
|
||||
extra_args?: object;
|
||||
}
|
||||
|
||||
|
||||
@@ -201,6 +201,9 @@ const enUS = {
|
||||
selectModelAbilities: 'Select model abilities',
|
||||
visionAbility: 'Vision Ability',
|
||||
functionCallAbility: 'Function Call',
|
||||
contextLength: 'Context Window',
|
||||
contextLengthPlaceholder: 'Unknown',
|
||||
contextLengthInvalid: 'Context window must be a positive integer',
|
||||
extraParameters: 'Extra Parameters',
|
||||
addParameter: 'Add Parameter',
|
||||
keyName: 'Key Name',
|
||||
|
||||
@@ -206,6 +206,9 @@ const esES = {
|
||||
selectModelAbilities: 'Seleccionar capacidades del modelo',
|
||||
visionAbility: 'Capacidad de visión',
|
||||
functionCallAbility: 'Llamada a funciones',
|
||||
contextLength: 'Ventana de contexto',
|
||||
contextLengthPlaceholder: 'Desconocido',
|
||||
contextLengthInvalid: 'La ventana de contexto debe ser un entero positivo',
|
||||
extraParameters: 'Parámetros adicionales',
|
||||
addParameter: 'Añadir parámetro',
|
||||
keyName: 'Nombre de la clave',
|
||||
|
||||
@@ -204,6 +204,10 @@ const jaJP = {
|
||||
selectModelAbilities: 'モデル機能を選択',
|
||||
visionAbility: '視覚機能',
|
||||
functionCallAbility: '関数呼び出し',
|
||||
contextLength: 'コンテキストウィンドウ',
|
||||
contextLengthPlaceholder: '不明',
|
||||
contextLengthInvalid:
|
||||
'コンテキストウィンドウは正の整数である必要があります',
|
||||
extraParameters: '追加パラメータ',
|
||||
addParameter: 'パラメータを追加',
|
||||
keyName: 'キー名',
|
||||
|
||||
@@ -203,6 +203,10 @@ const ruRU = {
|
||||
selectModelAbilities: 'Выберите возможности модели',
|
||||
visionAbility: 'Распознавание изображений',
|
||||
functionCallAbility: 'Вызов функций',
|
||||
contextLength: 'Контекстное окно',
|
||||
contextLengthPlaceholder: 'Неизвестно',
|
||||
contextLengthInvalid:
|
||||
'Контекстное окно должно быть положительным целым числом',
|
||||
extraParameters: 'Дополнительные параметры',
|
||||
addParameter: 'Добавить параметр',
|
||||
keyName: 'Имя ключа',
|
||||
|
||||
@@ -199,6 +199,9 @@ const thTH = {
|
||||
selectModelAbilities: 'เลือกความสามารถของโมเดล',
|
||||
visionAbility: 'ความสามารถด้านภาพ',
|
||||
functionCallAbility: 'การเรียกฟังก์ชัน',
|
||||
contextLength: 'หน้าต่างบริบท',
|
||||
contextLengthPlaceholder: 'ไม่ทราบ',
|
||||
contextLengthInvalid: 'หน้าต่างบริบทต้องเป็นจำนวนเต็มบวก',
|
||||
extraParameters: 'พารามิเตอร์เพิ่มเติม',
|
||||
addParameter: 'เพิ่มพารามิเตอร์',
|
||||
keyName: 'ชื่อคีย์',
|
||||
|
||||
@@ -203,6 +203,9 @@ const viVN = {
|
||||
selectModelAbilities: 'Chọn khả năng mô hình',
|
||||
visionAbility: 'Khả năng thị giác',
|
||||
functionCallAbility: 'Gọi hàm',
|
||||
contextLength: 'Cửa sổ ngữ cảnh',
|
||||
contextLengthPlaceholder: 'Không rõ',
|
||||
contextLengthInvalid: 'Cửa sổ ngữ cảnh phải là số nguyên dương',
|
||||
extraParameters: 'Tham số bổ sung',
|
||||
addParameter: 'Thêm tham số',
|
||||
keyName: 'Tên khóa',
|
||||
|
||||
@@ -193,6 +193,9 @@ const zhHans = {
|
||||
selectModelAbilities: '选择模型能力',
|
||||
visionAbility: '视觉能力',
|
||||
functionCallAbility: '函数调用',
|
||||
contextLength: '上下文窗口',
|
||||
contextLengthPlaceholder: '未知',
|
||||
contextLengthInvalid: '上下文窗口必须是正整数',
|
||||
extraParameters: '额外参数',
|
||||
addParameter: '添加参数',
|
||||
keyName: '键名',
|
||||
|
||||
@@ -193,6 +193,9 @@ const zhHant = {
|
||||
selectModelAbilities: '選擇模型能力',
|
||||
visionAbility: '視覺能力',
|
||||
functionCallAbility: '函數呼叫',
|
||||
contextLength: '上下文視窗',
|
||||
contextLengthPlaceholder: '未知',
|
||||
contextLengthInvalid: '上下文視窗必須是正整數',
|
||||
extraParameters: '額外參數',
|
||||
addParameter: '新增參數',
|
||||
keyName: '鍵名',
|
||||
|
||||
Reference in New Issue
Block a user