diff --git a/src/langbot/pkg/entity/persistence/model.py b/src/langbot/pkg/entity/persistence/model.py index 3c96acd7..5b5f1fe2 100644 --- a/src/langbot/pkg/entity/persistence/model.py +++ b/src/langbot/pkg/entity/persistence/model.py @@ -31,6 +31,7 @@ class LLMModel(Base): name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[]) + context_length = sqlalchemy.Column(sqlalchemy.Integer, nullable=True) extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0) created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) diff --git a/src/langbot/pkg/persistence/alembic/versions/0004_add_llm_model_context_length.py b/src/langbot/pkg/persistence/alembic/versions/0004_add_llm_model_context_length.py new file mode 100644 index 00000000..6bc9ddd2 --- /dev/null +++ b/src/langbot/pkg/persistence/alembic/versions/0004_add_llm_model_context_length.py @@ -0,0 +1,30 @@ +"""add llm model context length + +Revision ID: 0004_add_llm_model_context_length +Revises: 0003_add_rerank_models +Create Date: 2026-06-07 +""" + +import sqlalchemy as sa +from alembic import op + +revision = '0004_add_llm_model_context_length' +down_revision = '0003_add_rerank_models' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + columns = {column['name'] for column in inspector.get_columns('llm_models')} + if 'context_length' not in columns: + op.add_column('llm_models', sa.Column('context_length', sa.Integer(), nullable=True)) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + columns = {column['name'] for column in inspector.get_columns('llm_models')} + if 'context_length' in columns: + op.drop_column('llm_models', 'context_length') diff --git a/src/langbot/pkg/persistence/migrations/dbm026_llm_model_context_length.py b/src/langbot/pkg/persistence/migrations/dbm026_llm_model_context_length.py new file mode 100644 index 00000000..81d7031e --- /dev/null +++ b/src/langbot/pkg/persistence/migrations/dbm026_llm_model_context_length.py @@ -0,0 +1,42 @@ +import sqlalchemy +from .. import migration + + +@migration.migration_class(26) +class DBMigrateLLMModelContextLength(migration.DBMigration): + """Add context_length column to LLM models""" + + async def upgrade(self): + columns = await self._get_columns('llm_models') + if 'context_length' not in columns: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN context_length INTEGER') + ) + + async def downgrade(self): + columns = await self._get_columns('llm_models') + if 'context_length' not in columns: + return + + if self.ap.persistence_mgr.db.name == 'postgresql': + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE llm_models DROP COLUMN IF EXISTS context_length') + ) + else: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE llm_models DROP COLUMN context_length') + ) + + async def _get_columns(self, table_name: str) -> set[str]: + if self.ap.persistence_mgr.db.name == 'postgresql': + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.text(""" + SELECT column_name FROM information_schema.columns + WHERE table_name = :table_name + """), + {'table_name': table_name}, + ) + return {row[0] for row in result.fetchall()} + + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name})')) + return {row[1] for row in result.fetchall()} diff --git a/src/langbot/pkg/provider/modelmgr/modelmgr.py b/src/langbot/pkg/provider/modelmgr/modelmgr.py index 3609f147..0c577f1a 100644 --- a/src/langbot/pkg/provider/modelmgr/modelmgr.py +++ b/src/langbot/pkg/provider/modelmgr/modelmgr.py @@ -266,6 +266,7 @@ class ModelManager: name=model_info.get('name', ''), provider_uuid='', abilities=model_info.get('abilities', []), + context_length=model_info.get('context_length'), extra_args=model_info.get('extra_args', {}), ), provider=runtime_provider, @@ -460,6 +461,7 @@ class ModelManager: name=model_info.get('name', ''), provider_uuid=model_info.get('provider_uuid', ''), abilities=model_info.get('abilities', []), + context_length=model_info.get('context_length'), extra_args=model_info.get('extra_args', {}), ) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py index 236d4723..3c31bac9 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/litellmchat.py @@ -59,8 +59,15 @@ class LiteLLMRequester(requester.ProviderAPIRequester): litellm_model_name = self._build_litellm_model_name(model_name) if litellm_model_name != model_name: candidates.append((litellm_model_name, None)) + for metadata_provider in self._metadata_provider_candidates(model_name): + candidates.append((f'{metadata_provider}/{model_name}', None)) + tried_candidates: set[tuple[str, str | None]] = set() for candidate_model, candidate_provider in candidates: + candidate_key = (candidate_model, candidate_provider) + if candidate_key in tried_candidates: + continue + tried_candidates.add(candidate_key) try: if bool(helper(model=candidate_model, custom_llm_provider=candidate_provider)): return True @@ -68,24 +75,80 @@ class LiteLLMRequester(requester.ProviderAPIRequester): continue return False + def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None: + if not model_payload: + return None + + for field_name in ('context_length', 'context_window', 'max_context_length'): + value = model_payload.get(field_name) + if isinstance(value, bool): + continue + if isinstance(value, int) and value > 0: + return value + if isinstance(value, str) and value.isdigit(): + parsed_value = int(value) + if parsed_value > 0: + return parsed_value + return None + + def _metadata_provider_candidates(self, model_name: str) -> list[str]: + normalized_model_name = (model_name or '').lower() + candidates = [] + if normalized_model_name.startswith(('moonshot-', 'kimi-')): + candidates.append('moonshot') + if normalized_model_name.startswith('deepseek-'): + candidates.append('deepseek') + + base_url = self.requester_cfg.get('base_url', '').lower() + if 'moonshot' in base_url: + candidates.append('moonshot') + if 'deepseek' in base_url: + candidates.append('deepseek') + + deduped_candidates = [] + for candidate in candidates: + if candidate not in deduped_candidates: + deduped_candidates.append(candidate) + return deduped_candidates + + def _known_context_length_fallback(self, model_name: str) -> int | None: + normalized_model_name = (model_name or '').lower() + if normalized_model_name.startswith('deepseek-v4-'): + return 1_000_000 + if normalized_model_name.startswith(('kimi-k2.5', 'kimi-k2.6')): + return 256 * 1024 + if normalized_model_name.startswith('moonshot-v1-8k'): + return 8 * 1024 + if normalized_model_name.startswith('moonshot-v1-32k'): + return 32 * 1024 + if normalized_model_name.startswith('moonshot-v1-128k') or normalized_model_name == 'moonshot-v1-auto': + return 128 * 1024 + return None + def _safe_context_length(self, model_name: str) -> int | None: helper = getattr(litellm, 'get_max_tokens', None) if not callable(helper): - return None + return self._known_context_length_fallback(model_name) candidates = [model_name] litellm_model_name = self._build_litellm_model_name(model_name) if litellm_model_name != model_name: candidates.append(litellm_model_name) + for provider in self._metadata_provider_candidates(model_name): + candidates.append(f'{provider}/{model_name}') + tried_candidates = [] for candidate in candidates: + if candidate in tried_candidates: + continue + tried_candidates.append(candidate) try: max_tokens = helper(candidate) except Exception: continue if isinstance(max_tokens, int) and max_tokens > 0: return max_tokens - return None + return self._known_context_length_fallback(model_name) def _supports_function_calling(self, model_name: str) -> bool: return self._safe_litellm_bool_helper('supports_function_calling', model_name) @@ -101,7 +164,11 @@ class LiteLLMRequester(requester.ProviderAPIRequester): return 'embedding' return 'llm' - def _enrich_scanned_model(self, model_id: str) -> dict[str, typing.Any]: + def _enrich_scanned_model( + self, + model_id: str, + model_payload: dict[str, typing.Any] | None = None, + ) -> dict[str, typing.Any]: model_type = self._infer_model_type(model_id) scanned_model: dict[str, typing.Any] = { 'id': model_id, @@ -113,11 +180,17 @@ class LiteLLMRequester(requester.ProviderAPIRequester): abilities = [] if self._supports_function_calling(model_id): abilities.append('func_call') - if self._supports_vision(model_id): + supports_provider_reported_vision = bool( + model_payload + and (model_payload.get('supports_image_in') is True or model_payload.get('supports_vision') is True) + ) + if supports_provider_reported_vision or self._supports_vision(model_id): abilities.append('vision') scanned_model['abilities'] = abilities - context_length = self._safe_context_length(model_id) + context_length = self._context_length_from_scan_payload(model_payload) + if context_length is None: + context_length = self._safe_context_length(model_id) if context_length is not None: scanned_model['context_length'] = context_length @@ -557,7 +630,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester): if not model_id: continue - models.append(self._enrich_scanned_model(model_id)) + models.append(self._enrich_scanned_model(model_id, item)) models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower())) diff --git a/src/langbot/pkg/utils/constants.py b/src/langbot/pkg/utils/constants.py index 4fad9069..f97255ab 100644 --- a/src/langbot/pkg/utils/constants.py +++ b/src/langbot/pkg/utils/constants.py @@ -2,7 +2,7 @@ import langbot semantic_version = f'v{langbot.__version__}' -required_database_version = 25 +required_database_version = 26 """Tag the version of the database schema, used to check if the database needs to be migrated""" debug_mode = False diff --git a/tests/unit_tests/api/service/test_model_service.py b/tests/unit_tests/api/service/test_model_service.py index 6e6d2598..18d7ffbe 100644 --- a/tests/unit_tests/api/service/test_model_service.py +++ b/tests/unit_tests/api/service/test_model_service.py @@ -35,6 +35,7 @@ def _create_mock_llm_model( name: str = 'Test LLM', provider_uuid: str = 'provider-uuid', abilities: list = None, + context_length: int | None = None, extra_args: dict = None, ) -> Mock: """Helper to create mock LLMModel entity.""" @@ -43,6 +44,7 @@ def _create_mock_llm_model( model.name = name model.provider_uuid = provider_uuid model.abilities = abilities or [] + model.context_length = context_length model.extra_args = extra_args or {} return model @@ -142,10 +144,12 @@ class TestRuntimeModelData: 'name': 'Model', 'provider_uuid': 'provider', 'abilities': ['vision'], + 'context_length': 128000, 'extra_args': {'temp': 0.7}, } result = _runtime_model_data('uuid', update_payload) assert result['abilities'] == ['vision'] + assert result['context_length'] == 128000 assert result['extra_args'] == {'temp': 0.7} @@ -188,7 +192,7 @@ class TestLLMModelsServiceGetLLMModels: ap = SimpleNamespace() ap.persistence_mgr = SimpleNamespace() - model = _create_mock_llm_model() + model = _create_mock_llm_model(context_length=128000) provider = _create_mock_provider() mock_model_result = _create_mock_result([model]) @@ -206,6 +210,7 @@ class TestLLMModelsServiceGetLLMModels: 'uuid': entity.uuid, 'name': entity.name, 'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None, + 'context_length': getattr(entity, 'context_length', None), 'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None, } ) @@ -218,6 +223,7 @@ class TestLLMModelsServiceGetLLMModels: # Verify assert len(result) == 1 assert result[0]['name'] == 'Test LLM' + assert result[0]['context_length'] == 128000 async def test_get_llm_models_hide_secret_keys(self): """Hides secret API keys when include_secret=False.""" @@ -265,7 +271,7 @@ class TestLLMModelsServiceGetLLMModel: ap = SimpleNamespace() ap.persistence_mgr = SimpleNamespace() - model = _create_mock_llm_model(model_uuid='found-uuid') + model = _create_mock_llm_model(model_uuid='found-uuid', context_length=128000) provider = _create_mock_provider() mock_model_result = _create_mock_result([], first_item=model) @@ -279,11 +285,12 @@ class TestLLMModelsServiceGetLLMModel: ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute) ap.persistence_mgr.serialize_model = Mock( - return_value={ - 'uuid': 'found-uuid', - 'name': 'Test LLM', - 'provider_uuid': 'provider-uuid', - 'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']}, + side_effect=lambda model_cls, entity: { + 'uuid': entity.uuid, + 'name': entity.name, + 'provider_uuid': getattr(entity, 'provider_uuid', None), + 'context_length': getattr(entity, 'context_length', None), + 'api_keys': getattr(entity, 'api_keys', None), } ) @@ -295,6 +302,7 @@ class TestLLMModelsServiceGetLLMModel: # Verify assert result is not None assert result['uuid'] == 'found-uuid' + assert result['context_length'] == 128000 async def test_get_llm_model_not_found(self): """Returns None when model not found.""" @@ -402,6 +410,39 @@ class TestLLMModelsServiceCreateLLMModel: # Verify assert model_uuid == 'preserved-uuid' + async def test_create_llm_model_persists_context_length_as_column(self): + """Creates LLM model with context_length outside extra_args.""" + ap = SimpleNamespace() + ap.persistence_mgr = SimpleNamespace() + ap.model_mgr = SimpleNamespace() + ap.model_mgr.provider_dict = {'provider-uuid': Mock()} + ap.model_mgr.llm_models = [] + ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock()) + ap.pipeline_service = SimpleNamespace(update_pipeline=AsyncMock()) + + mock_result = _create_mock_result([]) + ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result) + + service = LLMModelsService(ap) + + await service.create_llm_model( + { + 'uuid': 'model-with-context', + 'name': 'Context Model', + 'provider_uuid': 'provider-uuid', + 'abilities': ['func_call'], + 'context_length': 128000, + 'extra_args': {'temperature': 0.2}, + }, + preserve_uuid=True, + auto_set_to_default_pipeline=False, + ) + + runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0] + assert runtime_entity.context_length == 128000 + assert runtime_entity.extra_args == {'temperature': 0.2} + assert 'context_length' not in runtime_entity.extra_args + async def test_create_llm_model_provider_not_found_raises_error(self): """Raises Exception when provider not found in runtime.""" # Setup @@ -512,6 +553,35 @@ class TestLLMModelsServiceUpdateLLMModel: 'provider_uuid': 'nonexistent-provider', }) + async def test_update_llm_model_reloads_context_length_as_column(self): + """Updates runtime model with context_length outside extra_args.""" + ap = SimpleNamespace() + ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock()) + ap.model_mgr = SimpleNamespace() + ap.model_mgr.provider_dict = {'provider-uuid': Mock()} + ap.model_mgr.llm_models = [] + ap.model_mgr.remove_llm_model = AsyncMock() + ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock()) + + service = LLMModelsService(ap) + + await service.update_llm_model( + 'existing-uuid', + { + 'name': 'Updated Name', + 'provider_uuid': 'provider-uuid', + 'abilities': ['vision'], + 'context_length': 64000, + 'extra_args': {'temperature': 0.4}, + }, + ) + + runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0] + assert runtime_entity.uuid == 'existing-uuid' + assert runtime_entity.context_length == 64000 + assert runtime_entity.extra_args == {'temperature': 0.4} + assert 'context_length' not in runtime_entity.extra_args + class TestLLMModelsServiceDeleteLLMModel: """Tests for LLMModelsService.delete_llm_model method.""" @@ -961,4 +1031,4 @@ class TestRerankModelsServiceGetRerankModelsByProvider: result = await service.get_rerank_models_by_provider('provider-uuid') # Verify - assert len(result) == 2 \ No newline at end of file + assert len(result) == 2 diff --git a/tests/unit_tests/provider/test_litellmchat.py b/tests/unit_tests/provider/test_litellmchat.py index f44ba4ba..e4efd4b8 100644 --- a/tests/unit_tests/provider/test_litellmchat.py +++ b/tests/unit_tests/provider/test_litellmchat.py @@ -896,6 +896,121 @@ class TestScanModels: assert by_id['text-embedding-3-small']['type'] == 'embedding' assert by_id['bge-reranker-v2']['type'] == 'rerank' + @pytest.mark.asyncio + async def test_scan_models_prefers_context_length_from_provider_payload(self): + """Provider-supplied context_length is preserved before LiteLLM metadata fallback.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.moonshot.cn/v1', + 'timeout': 60, + }, + ) + requester._supports_function_calling = Mock(return_value=False) + requester._supports_vision = Mock(return_value=False) + requester._safe_context_length = Mock(return_value=None) + + mock_response = Mock() + mock_response.json = Mock( + return_value={ + 'data': [ + {'id': 'moonshot-v1-128k', 'context_length': 131072}, + ] + } + ) + mock_response.raise_for_status = Mock() + + with patch('httpx.AsyncClient') as mock_client: + mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock()) + mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response) + + result = await requester.scan_models(api_key='test-key') + + assert result['models'][0]['context_length'] == 131072 + requester._safe_context_length.assert_not_called() + + def test_safe_context_length_tries_moonshot_metadata_alias(self): + """OpenAI-compatible Moonshot endpoints still use Moonshot metadata for context windows.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.moonshot.cn/v1', + 'custom_llm_provider': 'openai', + }, + ) + + with patch.object(litellmchat.litellm, 'get_max_tokens') as mock_get_max_tokens: + mock_get_max_tokens.side_effect = lambda model: 131072 if model == 'moonshot/moonshot-v1-128k' else None + + assert requester._safe_context_length('moonshot-v1-128k') == 131072 + + def test_litellm_bool_helper_tries_moonshot_metadata_alias(self): + """OpenAI-compatible Moonshot endpoints still use Moonshot metadata for abilities.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.moonshot.cn/v1', + 'custom_llm_provider': 'openai', + }, + ) + + with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling: + mock_supports_function_calling.side_effect = ( + lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6' + and custom_llm_provider is None + ) + + assert requester._supports_function_calling('kimi-k2.6') is True + + @pytest.mark.asyncio + async def test_scan_models_uses_provider_payload_for_vision_ability(self): + """Provider-supplied vision support is used when scanning models.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.moonshot.cn/v1', + 'timeout': 60, + }, + ) + requester._supports_function_calling = Mock(return_value=False) + requester._supports_vision = Mock(return_value=False) + requester._safe_context_length = Mock(return_value=None) + + mock_response = Mock() + mock_response.json = Mock( + return_value={ + 'data': [ + { + 'id': 'moonshot-v1-128k-vision-preview', + 'supports_image_in': True, + }, + ] + } + ) + mock_response.raise_for_status = Mock() + + with patch('httpx.AsyncClient') as mock_client: + mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock()) + mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response) + + result = await requester.scan_models(api_key='test-key') + + assert result['models'][0]['abilities'] == ['vision'] + + def test_safe_context_length_falls_back_for_deepseek_v4_models(self): + """DeepSeek V4 API ids have a known 1M context even before LiteLLM maps them.""" + requester = litellmchat.LiteLLMRequester( + ap=Mock(), + config={ + 'base_url': 'https://api.deepseek.com', + 'custom_llm_provider': 'deepseek', + }, + ) + + with patch.object(litellmchat.litellm, 'get_max_tokens', side_effect=Exception('not mapped')): + assert requester._safe_context_length('deepseek-v4-pro') == 1_000_000 + assert requester._safe_context_length('deepseek-v4-flash') == 1_000_000 + @pytest.mark.asyncio async def test_scan_models_no_base_url(self): """Test scan_models without base_url raises error""" diff --git a/tests/unit_tests/provider/test_model_manager.py b/tests/unit_tests/provider/test_model_manager.py index b38a5d02..b6e82d3f 100644 --- a/tests/unit_tests/provider/test_model_manager.py +++ b/tests/unit_tests/provider/test_model_manager.py @@ -494,6 +494,7 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg 'api_keys': ['temp-key'], }, 'abilities': ['func_call'], + 'context_length': 128000, 'extra_args': {'temperature': 0.5}, } @@ -501,6 +502,9 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg assert runtime_model.model_entity.uuid == 'temp-model-uuid' assert runtime_model.model_entity.name == 'TempModel' + assert runtime_model.model_entity.context_length == 128000 + assert runtime_model.model_entity.extra_args == {'temperature': 0.5} + assert 'context_length' not in runtime_model.model_entity.extra_args assert runtime_model.provider.provider_entity.uuid == 'temp-provider-uuid' assert runtime_model.provider.token_mgr.tokens == ['temp-key'] @@ -785,4 +789,4 @@ def test_provider_not_found_error_str(): error = provider_errors.ProviderNotFoundError('test-provider') assert str(error) == 'Provider test-provider not found' - assert error.provider_name == 'test-provider' \ No newline at end of file + assert error.provider_name == 'test-provider' diff --git a/web/src/app/home/components/models-dialog/ModelsDialog.tsx b/web/src/app/home/components/models-dialog/ModelsDialog.tsx index 16c6663d..78bccf81 100644 --- a/web/src/app/home/components/models-dialog/ModelsDialog.tsx +++ b/web/src/app/home/components/models-dialog/ModelsDialog.tsx @@ -64,6 +64,17 @@ function convertExtraArgsToObject( return obj; } +function parseContextLength( + value: number | null | undefined, + invalidMessage: string, +): number | null { + if (value === undefined || value === null) return null; + if (!Number.isInteger(value) || value <= 0) { + throw new Error(invalidMessage); + } + return value; +} + export default function ModelsDialog({ open, onOpenChange, @@ -254,6 +265,7 @@ export default function ModelsDialog({ name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) { if (!name.trim()) { toast.error(t('models.modelNameRequired')); @@ -268,6 +280,10 @@ export default function ModelsDialog({ name, provider_uuid: providerUuid, abilities, + context_length: parseContextLength( + contextLength, + t('models.contextLengthInvalid'), + ), extra_args: extraArgsObj, } as never); } else if (modelType === 'embedding') { @@ -325,6 +341,7 @@ export default function ModelsDialog({ name: item.model.name, provider_uuid: providerUuid, abilities: item.abilities, + context_length: item.model.context_length ?? null, extra_args: {}, } as never); } else if (effectiveType === 'embedding') { @@ -361,6 +378,7 @@ export default function ModelsDialog({ name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) { if (!name.trim()) { toast.error(t('models.modelNameRequired')); @@ -375,6 +393,10 @@ export default function ModelsDialog({ name, provider_uuid: providerUuid, abilities, + context_length: parseContextLength( + contextLength, + t('models.contextLengthInvalid'), + ), extra_args: extraArgsObj, } as never); } else if (modelType === 'embedding') { @@ -509,8 +531,15 @@ export default function ModelsDialog({ onSpaceLogin={handleSpaceLogin} onOpenAddModel={() => setAddModelPopoverOpen(provider.uuid)} onCloseAddModel={() => setAddModelPopoverOpen(null)} - onAddModel={(modelType, name, abilities, extraArgs) => - handleAddModel(provider.uuid, modelType, name, abilities, extraArgs) + onAddModel={(modelType, name, abilities, extraArgs, contextLength) => + handleAddModel( + provider.uuid, + modelType, + name, + abilities, + extraArgs, + contextLength, + ) } onScanModels={(modelType) => handleScanModels(provider.uuid, modelType)} onAddScannedModels={(modelType, models) => @@ -518,7 +547,14 @@ export default function ModelsDialog({ } onOpenEditModel={(modelId) => setEditModelPopoverOpen(modelId)} onCloseEditModel={() => setEditModelPopoverOpen(null)} - onUpdateModel={(modelId, modelType, name, abilities, extraArgs) => + onUpdateModel={( + modelId, + modelType, + name, + abilities, + extraArgs, + contextLength, + ) => handleUpdateModel( provider.uuid, modelId, @@ -526,6 +562,7 @@ export default function ModelsDialog({ name, abilities, extraArgs, + contextLength, ) } onOpenDeleteConfirm={(modelId) => setDeleteConfirmOpen(modelId)} diff --git a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx index 382b5a0f..0bb2d1c2 100644 --- a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx +++ b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx @@ -41,6 +41,7 @@ interface AddModelPopoverProps { name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) => Promise; onScanModels: (modelType?: ModelType) => Promise; onAddScannedModels: ( @@ -81,6 +82,7 @@ export default function AddModelPopover({ const [mode, setMode] = useState<'manual' | 'scan'>('manual'); const [name, setName] = useState(''); const [abilities, setAbilities] = useState([]); + const [contextLength, setContextLength] = useState(''); const [extraArgs, setExtraArgs] = useState([]); const [scanLoading, setScanLoading] = useState(false); const [scannedModels, setScannedModels] = useState( @@ -98,6 +100,7 @@ export default function AddModelPopover({ setMode(initialMode); setName(''); setAbilities([]); + setContextLength(''); setExtraArgs([]); setScanLoading(false); setScannedModels([]); @@ -119,7 +122,11 @@ export default function AddModelPopover({ }, [tab, mode]); const handleAdd = async () => { - await onAddModel(tab, name, abilities, extraArgs); + const parsedContextLength = + tab === 'llm' && contextLength.trim() + ? Number(contextLength.trim()) + : null; + await onAddModel(tab, name, abilities, extraArgs, parsedContextLength); }; const handleTest = async () => { @@ -318,6 +325,24 @@ export default function AddModelPopover({ )} + {tab === 'llm' && ( +
+ + setContextLength(e.target.value)} + /> +
+ )} + Promise; onTestModel: ( name: string, @@ -92,6 +93,11 @@ export default function ModelItem({ const [editAbilities, setEditAbilities] = useState( modelType === 'llm' ? (model as LLMModel).abilities || [] : [], ); + const [editContextLength, setEditContextLength] = useState( + modelType === 'llm' && (model as LLMModel).context_length + ? String((model as LLMModel).context_length) + : '', + ); const [editExtraArgs, setEditExtraArgs] = useState( convertExtraArgsToArray(model.extra_args), ); @@ -106,13 +112,27 @@ export default function ModelItem({ setEditAbilities( modelType === 'llm' ? (model as LLMModel).abilities || [] : [], ); + setEditContextLength( + modelType === 'llm' && (model as LLMModel).context_length + ? String((model as LLMModel).context_length) + : '', + ); setEditExtraArgs(convertExtraArgsToArray(model.extra_args)); onResetTestResult(); } }, [isEditOpen]); const handleSave = async () => { - await onUpdateModel(editName, editAbilities, editExtraArgs); + const parsedContextLength = + modelType === 'llm' && editContextLength.trim() + ? Number(editContextLength.trim()) + : null; + await onUpdateModel( + editName, + editAbilities, + editExtraArgs, + parsedContextLength, + ); }; const handleTest = async () => { @@ -287,6 +307,25 @@ export default function ModelItem({ )} + {modelType === 'llm' && ( +
+ + setEditContextLength(e.target.value)} + /> +
+ )} + Promise; onScanModels: (modelType?: ModelType) => Promise; onAddScannedModels: ( @@ -74,6 +75,7 @@ interface ProviderCardProps { name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) => Promise; onOpenDeleteConfirm: (modelId: string) => void; onCloseDeleteConfirm: () => void; @@ -405,13 +407,19 @@ export default function ProviderCard({ onOpenDeleteConfirm={onOpenDeleteConfirm} onCloseDeleteConfirm={onCloseDeleteConfirm} onDeleteModel={() => onDeleteModel(model.uuid, 'llm')} - onUpdateModel={(name, abilities, extraArgs) => + onUpdateModel={( + name, + abilities, + extraArgs, + contextLength, + ) => onUpdateModel( model.uuid, 'llm', name, abilities, extraArgs, + contextLength, ) } onTestModel={(name, abilities, extraArgs) => diff --git a/web/src/app/home/components/models-dialog/types.ts b/web/src/app/home/components/models-dialog/types.ts index d2ecb7f1..cc52906a 100644 --- a/web/src/app/home/components/models-dialog/types.ts +++ b/web/src/app/home/components/models-dialog/types.ts @@ -53,6 +53,7 @@ export interface ModelItemProps { name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) => Promise; onTest: ( name: string, @@ -89,6 +90,7 @@ export interface ProviderCardProps { name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) => Promise; onScanModels: (modelType?: ModelType) => Promise; onAddScannedModels: ( @@ -103,6 +105,7 @@ export interface ProviderCardProps { name: string, abilities: string[], extraArgs: ExtraArg[], + contextLength?: number | null, ) => Promise; onOpenDeleteConfirm: (modelId: string) => void; onCloseDeleteConfirm: () => void; diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index 171b4d47..5a7694d2 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -96,6 +96,7 @@ export interface LLMModel { provider_uuid: string; provider?: ModelProvider; abilities?: string[]; + context_length?: number | null; extra_args?: object; } diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 17e82d0d..a1d3e78a 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -201,6 +201,9 @@ const enUS = { selectModelAbilities: 'Select model abilities', visionAbility: 'Vision Ability', functionCallAbility: 'Function Call', + contextLength: 'Context Window', + contextLengthPlaceholder: 'Unknown', + contextLengthInvalid: 'Context window must be a positive integer', extraParameters: 'Extra Parameters', addParameter: 'Add Parameter', keyName: 'Key Name', diff --git a/web/src/i18n/locales/es-ES.ts b/web/src/i18n/locales/es-ES.ts index 40f30bbd..5355e5d8 100644 --- a/web/src/i18n/locales/es-ES.ts +++ b/web/src/i18n/locales/es-ES.ts @@ -206,6 +206,9 @@ const esES = { selectModelAbilities: 'Seleccionar capacidades del modelo', visionAbility: 'Capacidad de visión', functionCallAbility: 'Llamada a funciones', + contextLength: 'Ventana de contexto', + contextLengthPlaceholder: 'Desconocido', + contextLengthInvalid: 'La ventana de contexto debe ser un entero positivo', extraParameters: 'Parámetros adicionales', addParameter: 'Añadir parámetro', keyName: 'Nombre de la clave', diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index 845b6cc7..8012d55a 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -204,6 +204,10 @@ const jaJP = { selectModelAbilities: 'モデル機能を選択', visionAbility: '視覚機能', functionCallAbility: '関数呼び出し', + contextLength: 'コンテキストウィンドウ', + contextLengthPlaceholder: '不明', + contextLengthInvalid: + 'コンテキストウィンドウは正の整数である必要があります', extraParameters: '追加パラメータ', addParameter: 'パラメータを追加', keyName: 'キー名', diff --git a/web/src/i18n/locales/ru-RU.ts b/web/src/i18n/locales/ru-RU.ts index 66421cbb..bbc0cf09 100644 --- a/web/src/i18n/locales/ru-RU.ts +++ b/web/src/i18n/locales/ru-RU.ts @@ -203,6 +203,10 @@ const ruRU = { selectModelAbilities: 'Выберите возможности модели', visionAbility: 'Распознавание изображений', functionCallAbility: 'Вызов функций', + contextLength: 'Контекстное окно', + contextLengthPlaceholder: 'Неизвестно', + contextLengthInvalid: + 'Контекстное окно должно быть положительным целым числом', extraParameters: 'Дополнительные параметры', addParameter: 'Добавить параметр', keyName: 'Имя ключа', diff --git a/web/src/i18n/locales/th-TH.ts b/web/src/i18n/locales/th-TH.ts index 4e6e32f8..7be3480a 100644 --- a/web/src/i18n/locales/th-TH.ts +++ b/web/src/i18n/locales/th-TH.ts @@ -199,6 +199,9 @@ const thTH = { selectModelAbilities: 'เลือกความสามารถของโมเดล', visionAbility: 'ความสามารถด้านภาพ', functionCallAbility: 'การเรียกฟังก์ชัน', + contextLength: 'หน้าต่างบริบท', + contextLengthPlaceholder: 'ไม่ทราบ', + contextLengthInvalid: 'หน้าต่างบริบทต้องเป็นจำนวนเต็มบวก', extraParameters: 'พารามิเตอร์เพิ่มเติม', addParameter: 'เพิ่มพารามิเตอร์', keyName: 'ชื่อคีย์', diff --git a/web/src/i18n/locales/vi-VN.ts b/web/src/i18n/locales/vi-VN.ts index ffd49fa0..f15a5db1 100644 --- a/web/src/i18n/locales/vi-VN.ts +++ b/web/src/i18n/locales/vi-VN.ts @@ -203,6 +203,9 @@ const viVN = { selectModelAbilities: 'Chọn khả năng mô hình', visionAbility: 'Khả năng thị giác', functionCallAbility: 'Gọi hàm', + contextLength: 'Cửa sổ ngữ cảnh', + contextLengthPlaceholder: 'Không rõ', + contextLengthInvalid: 'Cửa sổ ngữ cảnh phải là số nguyên dương', extraParameters: 'Tham số bổ sung', addParameter: 'Thêm tham số', keyName: 'Tên khóa', diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 227a851b..d9f3c65a 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -193,6 +193,9 @@ const zhHans = { selectModelAbilities: '选择模型能力', visionAbility: '视觉能力', functionCallAbility: '函数调用', + contextLength: '上下文窗口', + contextLengthPlaceholder: '未知', + contextLengthInvalid: '上下文窗口必须是正整数', extraParameters: '额外参数', addParameter: '添加参数', keyName: '键名', diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index 855c1756..ed8de946 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -193,6 +193,9 @@ const zhHant = { selectModelAbilities: '選擇模型能力', visionAbility: '視覺能力', functionCallAbility: '函數呼叫', + contextLength: '上下文視窗', + contextLengthPlaceholder: '未知', + contextLengthInvalid: '上下文視窗必須是正整數', extraParameters: '額外參數', addParameter: '新增參數', keyName: '鍵名',