feat(models): persist context metadata

This commit is contained in:
huanghuoguoguo
2026-06-08 00:39:30 +08:00
parent 573e1fe36e
commit b82db2b7f8
23 changed files with 498 additions and 22 deletions

View File

@@ -31,6 +31,7 @@ class LLMModel(Base):
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
context_length = sqlalchemy.Column(sqlalchemy.Integer, nullable=True)
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())

View File

@@ -0,0 +1,30 @@
"""add llm model context length
Revision ID: 0004_add_llm_model_context_length
Revises: 0003_add_rerank_models
Create Date: 2026-06-07
"""
import sqlalchemy as sa
from alembic import op
revision = '0004_add_llm_model_context_length'
down_revision = '0003_add_rerank_models'
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = {column['name'] for column in inspector.get_columns('llm_models')}
if 'context_length' not in columns:
op.add_column('llm_models', sa.Column('context_length', sa.Integer(), nullable=True))
def downgrade() -> None:
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = {column['name'] for column in inspector.get_columns('llm_models')}
if 'context_length' in columns:
op.drop_column('llm_models', 'context_length')

View File

@@ -0,0 +1,42 @@
import sqlalchemy
from .. import migration
@migration.migration_class(26)
class DBMigrateLLMModelContextLength(migration.DBMigration):
"""Add context_length column to LLM models"""
async def upgrade(self):
columns = await self._get_columns('llm_models')
if 'context_length' not in columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE llm_models ADD COLUMN context_length INTEGER')
)
async def downgrade(self):
columns = await self._get_columns('llm_models')
if 'context_length' not in columns:
return
if self.ap.persistence_mgr.db.name == 'postgresql':
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE llm_models DROP COLUMN IF EXISTS context_length')
)
else:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text('ALTER TABLE llm_models DROP COLUMN context_length')
)
async def _get_columns(self, table_name: str) -> set[str]:
if self.ap.persistence_mgr.db.name == 'postgresql':
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("""
SELECT column_name FROM information_schema.columns
WHERE table_name = :table_name
"""),
{'table_name': table_name},
)
return {row[0] for row in result.fetchall()}
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text(f'PRAGMA table_info({table_name})'))
return {row[1] for row in result.fetchall()}

View File

@@ -266,6 +266,7 @@ class ModelManager:
name=model_info.get('name', ''),
provider_uuid='',
abilities=model_info.get('abilities', []),
context_length=model_info.get('context_length'),
extra_args=model_info.get('extra_args', {}),
),
provider=runtime_provider,
@@ -460,6 +461,7 @@ class ModelManager:
name=model_info.get('name', ''),
provider_uuid=model_info.get('provider_uuid', ''),
abilities=model_info.get('abilities', []),
context_length=model_info.get('context_length'),
extra_args=model_info.get('extra_args', {}),
)

View File

@@ -59,8 +59,15 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
litellm_model_name = self._build_litellm_model_name(model_name)
if litellm_model_name != model_name:
candidates.append((litellm_model_name, None))
for metadata_provider in self._metadata_provider_candidates(model_name):
candidates.append((f'{metadata_provider}/{model_name}', None))
tried_candidates: set[tuple[str, str | None]] = set()
for candidate_model, candidate_provider in candidates:
candidate_key = (candidate_model, candidate_provider)
if candidate_key in tried_candidates:
continue
tried_candidates.add(candidate_key)
try:
if bool(helper(model=candidate_model, custom_llm_provider=candidate_provider)):
return True
@@ -68,24 +75,80 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
continue
return False
def _context_length_from_scan_payload(self, model_payload: dict[str, typing.Any] | None) -> int | None:
if not model_payload:
return None
for field_name in ('context_length', 'context_window', 'max_context_length'):
value = model_payload.get(field_name)
if isinstance(value, bool):
continue
if isinstance(value, int) and value > 0:
return value
if isinstance(value, str) and value.isdigit():
parsed_value = int(value)
if parsed_value > 0:
return parsed_value
return None
def _metadata_provider_candidates(self, model_name: str) -> list[str]:
normalized_model_name = (model_name or '').lower()
candidates = []
if normalized_model_name.startswith(('moonshot-', 'kimi-')):
candidates.append('moonshot')
if normalized_model_name.startswith('deepseek-'):
candidates.append('deepseek')
base_url = self.requester_cfg.get('base_url', '').lower()
if 'moonshot' in base_url:
candidates.append('moonshot')
if 'deepseek' in base_url:
candidates.append('deepseek')
deduped_candidates = []
for candidate in candidates:
if candidate not in deduped_candidates:
deduped_candidates.append(candidate)
return deduped_candidates
def _known_context_length_fallback(self, model_name: str) -> int | None:
normalized_model_name = (model_name or '').lower()
if normalized_model_name.startswith('deepseek-v4-'):
return 1_000_000
if normalized_model_name.startswith(('kimi-k2.5', 'kimi-k2.6')):
return 256 * 1024
if normalized_model_name.startswith('moonshot-v1-8k'):
return 8 * 1024
if normalized_model_name.startswith('moonshot-v1-32k'):
return 32 * 1024
if normalized_model_name.startswith('moonshot-v1-128k') or normalized_model_name == 'moonshot-v1-auto':
return 128 * 1024
return None
def _safe_context_length(self, model_name: str) -> int | None:
helper = getattr(litellm, 'get_max_tokens', None)
if not callable(helper):
return None
return self._known_context_length_fallback(model_name)
candidates = [model_name]
litellm_model_name = self._build_litellm_model_name(model_name)
if litellm_model_name != model_name:
candidates.append(litellm_model_name)
for provider in self._metadata_provider_candidates(model_name):
candidates.append(f'{provider}/{model_name}')
tried_candidates = []
for candidate in candidates:
if candidate in tried_candidates:
continue
tried_candidates.append(candidate)
try:
max_tokens = helper(candidate)
except Exception:
continue
if isinstance(max_tokens, int) and max_tokens > 0:
return max_tokens
return None
return self._known_context_length_fallback(model_name)
def _supports_function_calling(self, model_name: str) -> bool:
return self._safe_litellm_bool_helper('supports_function_calling', model_name)
@@ -101,7 +164,11 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
return 'embedding'
return 'llm'
def _enrich_scanned_model(self, model_id: str) -> dict[str, typing.Any]:
def _enrich_scanned_model(
self,
model_id: str,
model_payload: dict[str, typing.Any] | None = None,
) -> dict[str, typing.Any]:
model_type = self._infer_model_type(model_id)
scanned_model: dict[str, typing.Any] = {
'id': model_id,
@@ -113,11 +180,17 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
abilities = []
if self._supports_function_calling(model_id):
abilities.append('func_call')
if self._supports_vision(model_id):
supports_provider_reported_vision = bool(
model_payload
and (model_payload.get('supports_image_in') is True or model_payload.get('supports_vision') is True)
)
if supports_provider_reported_vision or self._supports_vision(model_id):
abilities.append('vision')
scanned_model['abilities'] = abilities
context_length = self._safe_context_length(model_id)
context_length = self._context_length_from_scan_payload(model_payload)
if context_length is None:
context_length = self._safe_context_length(model_id)
if context_length is not None:
scanned_model['context_length'] = context_length
@@ -557,7 +630,7 @@ class LiteLLMRequester(requester.ProviderAPIRequester):
if not model_id:
continue
models.append(self._enrich_scanned_model(model_id))
models.append(self._enrich_scanned_model(model_id, item))
models.sort(key=lambda x: (x['type'] != 'llm', x['name'].lower()))

View File

@@ -2,7 +2,7 @@ import langbot
semantic_version = f'v{langbot.__version__}'
required_database_version = 25
required_database_version = 26
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
debug_mode = False

View File

@@ -35,6 +35,7 @@ def _create_mock_llm_model(
name: str = 'Test LLM',
provider_uuid: str = 'provider-uuid',
abilities: list = None,
context_length: int | None = None,
extra_args: dict = None,
) -> Mock:
"""Helper to create mock LLMModel entity."""
@@ -43,6 +44,7 @@ def _create_mock_llm_model(
model.name = name
model.provider_uuid = provider_uuid
model.abilities = abilities or []
model.context_length = context_length
model.extra_args = extra_args or {}
return model
@@ -142,10 +144,12 @@ class TestRuntimeModelData:
'name': 'Model',
'provider_uuid': 'provider',
'abilities': ['vision'],
'context_length': 128000,
'extra_args': {'temp': 0.7},
}
result = _runtime_model_data('uuid', update_payload)
assert result['abilities'] == ['vision']
assert result['context_length'] == 128000
assert result['extra_args'] == {'temp': 0.7}
@@ -188,7 +192,7 @@ class TestLLMModelsServiceGetLLMModels:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_llm_model()
model = _create_mock_llm_model(context_length=128000)
provider = _create_mock_provider()
mock_model_result = _create_mock_result([model])
@@ -206,6 +210,7 @@ class TestLLMModelsServiceGetLLMModels:
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': entity.provider_uuid if hasattr(entity, 'provider_uuid') else None,
'context_length': getattr(entity, 'context_length', None),
'api_keys': entity.api_keys if hasattr(entity, 'api_keys') else None,
}
)
@@ -218,6 +223,7 @@ class TestLLMModelsServiceGetLLMModels:
# Verify
assert len(result) == 1
assert result[0]['name'] == 'Test LLM'
assert result[0]['context_length'] == 128000
async def test_get_llm_models_hide_secret_keys(self):
"""Hides secret API keys when include_secret=False."""
@@ -265,7 +271,7 @@ class TestLLMModelsServiceGetLLMModel:
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
model = _create_mock_llm_model(model_uuid='found-uuid')
model = _create_mock_llm_model(model_uuid='found-uuid', context_length=128000)
provider = _create_mock_provider()
mock_model_result = _create_mock_result([], first_item=model)
@@ -279,11 +285,12 @@ class TestLLMModelsServiceGetLLMModel:
ap.persistence_mgr.execute_async = AsyncMock(side_effect=mock_execute)
ap.persistence_mgr.serialize_model = Mock(
return_value={
'uuid': 'found-uuid',
'name': 'Test LLM',
'provider_uuid': 'provider-uuid',
'provider': {'uuid': 'provider-uuid', 'api_keys': ['key']},
side_effect=lambda model_cls, entity: {
'uuid': entity.uuid,
'name': entity.name,
'provider_uuid': getattr(entity, 'provider_uuid', None),
'context_length': getattr(entity, 'context_length', None),
'api_keys': getattr(entity, 'api_keys', None),
}
)
@@ -295,6 +302,7 @@ class TestLLMModelsServiceGetLLMModel:
# Verify
assert result is not None
assert result['uuid'] == 'found-uuid'
assert result['context_length'] == 128000
async def test_get_llm_model_not_found(self):
"""Returns None when model not found."""
@@ -402,6 +410,39 @@ class TestLLMModelsServiceCreateLLMModel:
# Verify
assert model_uuid == 'preserved-uuid'
async def test_create_llm_model_persists_context_length_as_column(self):
"""Creates LLM model with context_length outside extra_args."""
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace()
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.llm_models = []
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
ap.pipeline_service = SimpleNamespace(update_pipeline=AsyncMock())
mock_result = _create_mock_result([])
ap.persistence_mgr.execute_async = AsyncMock(return_value=mock_result)
service = LLMModelsService(ap)
await service.create_llm_model(
{
'uuid': 'model-with-context',
'name': 'Context Model',
'provider_uuid': 'provider-uuid',
'abilities': ['func_call'],
'context_length': 128000,
'extra_args': {'temperature': 0.2},
},
preserve_uuid=True,
auto_set_to_default_pipeline=False,
)
runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0]
assert runtime_entity.context_length == 128000
assert runtime_entity.extra_args == {'temperature': 0.2}
assert 'context_length' not in runtime_entity.extra_args
async def test_create_llm_model_provider_not_found_raises_error(self):
"""Raises Exception when provider not found in runtime."""
# Setup
@@ -512,6 +553,35 @@ class TestLLMModelsServiceUpdateLLMModel:
'provider_uuid': 'nonexistent-provider',
})
async def test_update_llm_model_reloads_context_length_as_column(self):
"""Updates runtime model with context_length outside extra_args."""
ap = SimpleNamespace()
ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
ap.model_mgr = SimpleNamespace()
ap.model_mgr.provider_dict = {'provider-uuid': Mock()}
ap.model_mgr.llm_models = []
ap.model_mgr.remove_llm_model = AsyncMock()
ap.model_mgr.load_llm_model_with_provider = AsyncMock(return_value=Mock())
service = LLMModelsService(ap)
await service.update_llm_model(
'existing-uuid',
{
'name': 'Updated Name',
'provider_uuid': 'provider-uuid',
'abilities': ['vision'],
'context_length': 64000,
'extra_args': {'temperature': 0.4},
},
)
runtime_entity = ap.model_mgr.load_llm_model_with_provider.await_args.args[0]
assert runtime_entity.uuid == 'existing-uuid'
assert runtime_entity.context_length == 64000
assert runtime_entity.extra_args == {'temperature': 0.4}
assert 'context_length' not in runtime_entity.extra_args
class TestLLMModelsServiceDeleteLLMModel:
"""Tests for LLMModelsService.delete_llm_model method."""
@@ -961,4 +1031,4 @@ class TestRerankModelsServiceGetRerankModelsByProvider:
result = await service.get_rerank_models_by_provider('provider-uuid')
# Verify
assert len(result) == 2
assert len(result) == 2

View File

@@ -896,6 +896,121 @@ class TestScanModels:
assert by_id['text-embedding-3-small']['type'] == 'embedding'
assert by_id['bge-reranker-v2']['type'] == 'rerank'
@pytest.mark.asyncio
async def test_scan_models_prefers_context_length_from_provider_payload(self):
"""Provider-supplied context_length is preserved before LiteLLM metadata fallback."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.moonshot.cn/v1',
'timeout': 60,
},
)
requester._supports_function_calling = Mock(return_value=False)
requester._supports_vision = Mock(return_value=False)
requester._safe_context_length = Mock(return_value=None)
mock_response = Mock()
mock_response.json = Mock(
return_value={
'data': [
{'id': 'moonshot-v1-128k', 'context_length': 131072},
]
}
)
mock_response.raise_for_status = Mock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await requester.scan_models(api_key='test-key')
assert result['models'][0]['context_length'] == 131072
requester._safe_context_length.assert_not_called()
def test_safe_context_length_tries_moonshot_metadata_alias(self):
"""OpenAI-compatible Moonshot endpoints still use Moonshot metadata for context windows."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.moonshot.cn/v1',
'custom_llm_provider': 'openai',
},
)
with patch.object(litellmchat.litellm, 'get_max_tokens') as mock_get_max_tokens:
mock_get_max_tokens.side_effect = lambda model: 131072 if model == 'moonshot/moonshot-v1-128k' else None
assert requester._safe_context_length('moonshot-v1-128k') == 131072
def test_litellm_bool_helper_tries_moonshot_metadata_alias(self):
"""OpenAI-compatible Moonshot endpoints still use Moonshot metadata for abilities."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.moonshot.cn/v1',
'custom_llm_provider': 'openai',
},
)
with patch.object(litellmchat.litellm, 'supports_function_calling') as mock_supports_function_calling:
mock_supports_function_calling.side_effect = (
lambda model, custom_llm_provider=None: model == 'moonshot/kimi-k2.6'
and custom_llm_provider is None
)
assert requester._supports_function_calling('kimi-k2.6') is True
@pytest.mark.asyncio
async def test_scan_models_uses_provider_payload_for_vision_ability(self):
"""Provider-supplied vision support is used when scanning models."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.moonshot.cn/v1',
'timeout': 60,
},
)
requester._supports_function_calling = Mock(return_value=False)
requester._supports_vision = Mock(return_value=False)
requester._safe_context_length = Mock(return_value=None)
mock_response = Mock()
mock_response.json = Mock(
return_value={
'data': [
{
'id': 'moonshot-v1-128k-vision-preview',
'supports_image_in': True,
},
]
}
)
mock_response.raise_for_status = Mock()
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__ = AsyncMock(return_value=Mock())
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await requester.scan_models(api_key='test-key')
assert result['models'][0]['abilities'] == ['vision']
def test_safe_context_length_falls_back_for_deepseek_v4_models(self):
"""DeepSeek V4 API ids have a known 1M context even before LiteLLM maps them."""
requester = litellmchat.LiteLLMRequester(
ap=Mock(),
config={
'base_url': 'https://api.deepseek.com',
'custom_llm_provider': 'deepseek',
},
)
with patch.object(litellmchat.litellm, 'get_max_tokens', side_effect=Exception('not mapped')):
assert requester._safe_context_length('deepseek-v4-pro') == 1_000_000
assert requester._safe_context_length('deepseek-v4-flash') == 1_000_000
@pytest.mark.asyncio
async def test_scan_models_no_base_url(self):
"""Test scan_models without base_url raises error"""

View File

@@ -494,6 +494,7 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
'api_keys': ['temp-key'],
},
'abilities': ['func_call'],
'context_length': 128000,
'extra_args': {'temperature': 0.5},
}
@@ -501,6 +502,9 @@ async def test_model_manager_init_temporary_runtime_llm_model(fake_requester_reg
assert runtime_model.model_entity.uuid == 'temp-model-uuid'
assert runtime_model.model_entity.name == 'TempModel'
assert runtime_model.model_entity.context_length == 128000
assert runtime_model.model_entity.extra_args == {'temperature': 0.5}
assert 'context_length' not in runtime_model.model_entity.extra_args
assert runtime_model.provider.provider_entity.uuid == 'temp-provider-uuid'
assert runtime_model.provider.token_mgr.tokens == ['temp-key']
@@ -785,4 +789,4 @@ def test_provider_not_found_error_str():
error = provider_errors.ProviderNotFoundError('test-provider')
assert str(error) == 'Provider test-provider not found'
assert error.provider_name == 'test-provider'
assert error.provider_name == 'test-provider'

View File

@@ -64,6 +64,17 @@ function convertExtraArgsToObject(
return obj;
}
function parseContextLength(
value: number | null | undefined,
invalidMessage: string,
): number | null {
if (value === undefined || value === null) return null;
if (!Number.isInteger(value) || value <= 0) {
throw new Error(invalidMessage);
}
return value;
}
export default function ModelsDialog({
open,
onOpenChange,
@@ -254,6 +265,7 @@ export default function ModelsDialog({
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) {
if (!name.trim()) {
toast.error(t('models.modelNameRequired'));
@@ -268,6 +280,10 @@ export default function ModelsDialog({
name,
provider_uuid: providerUuid,
abilities,
context_length: parseContextLength(
contextLength,
t('models.contextLengthInvalid'),
),
extra_args: extraArgsObj,
} as never);
} else if (modelType === 'embedding') {
@@ -325,6 +341,7 @@ export default function ModelsDialog({
name: item.model.name,
provider_uuid: providerUuid,
abilities: item.abilities,
context_length: item.model.context_length ?? null,
extra_args: {},
} as never);
} else if (effectiveType === 'embedding') {
@@ -361,6 +378,7 @@ export default function ModelsDialog({
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) {
if (!name.trim()) {
toast.error(t('models.modelNameRequired'));
@@ -375,6 +393,10 @@ export default function ModelsDialog({
name,
provider_uuid: providerUuid,
abilities,
context_length: parseContextLength(
contextLength,
t('models.contextLengthInvalid'),
),
extra_args: extraArgsObj,
} as never);
} else if (modelType === 'embedding') {
@@ -509,8 +531,15 @@ export default function ModelsDialog({
onSpaceLogin={handleSpaceLogin}
onOpenAddModel={() => setAddModelPopoverOpen(provider.uuid)}
onCloseAddModel={() => setAddModelPopoverOpen(null)}
onAddModel={(modelType, name, abilities, extraArgs) =>
handleAddModel(provider.uuid, modelType, name, abilities, extraArgs)
onAddModel={(modelType, name, abilities, extraArgs, contextLength) =>
handleAddModel(
provider.uuid,
modelType,
name,
abilities,
extraArgs,
contextLength,
)
}
onScanModels={(modelType) => handleScanModels(provider.uuid, modelType)}
onAddScannedModels={(modelType, models) =>
@@ -518,7 +547,14 @@ export default function ModelsDialog({
}
onOpenEditModel={(modelId) => setEditModelPopoverOpen(modelId)}
onCloseEditModel={() => setEditModelPopoverOpen(null)}
onUpdateModel={(modelId, modelType, name, abilities, extraArgs) =>
onUpdateModel={(
modelId,
modelType,
name,
abilities,
extraArgs,
contextLength,
) =>
handleUpdateModel(
provider.uuid,
modelId,
@@ -526,6 +562,7 @@ export default function ModelsDialog({
name,
abilities,
extraArgs,
contextLength,
)
}
onOpenDeleteConfirm={(modelId) => setDeleteConfirmOpen(modelId)}

View File

@@ -41,6 +41,7 @@ interface AddModelPopoverProps {
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) => Promise<void>;
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
onAddScannedModels: (
@@ -81,6 +82,7 @@ export default function AddModelPopover({
const [mode, setMode] = useState<'manual' | 'scan'>('manual');
const [name, setName] = useState('');
const [abilities, setAbilities] = useState<string[]>([]);
const [contextLength, setContextLength] = useState('');
const [extraArgs, setExtraArgs] = useState<ExtraArg[]>([]);
const [scanLoading, setScanLoading] = useState(false);
const [scannedModels, setScannedModels] = useState<ScannedProviderModel[]>(
@@ -98,6 +100,7 @@ export default function AddModelPopover({
setMode(initialMode);
setName('');
setAbilities([]);
setContextLength('');
setExtraArgs([]);
setScanLoading(false);
setScannedModels([]);
@@ -119,7 +122,11 @@ export default function AddModelPopover({
}, [tab, mode]);
const handleAdd = async () => {
await onAddModel(tab, name, abilities, extraArgs);
const parsedContextLength =
tab === 'llm' && contextLength.trim()
? Number(contextLength.trim())
: null;
await onAddModel(tab, name, abilities, extraArgs, parsedContextLength);
};
const handleTest = async () => {
@@ -318,6 +325,24 @@ export default function AddModelPopover({
</div>
)}
{tab === 'llm' && (
<div className="space-y-2">
<Label htmlFor="add-context-length">
{t('models.contextLength')}
</Label>
<Input
id="add-context-length"
type="number"
min={1}
step={1}
inputMode="numeric"
placeholder={t('models.contextLengthPlaceholder')}
value={contextLength}
onChange={(e) => setContextLength(e.target.value)}
/>
</div>
)}
<ExtraArgsEditor
args={extraArgs}
onChange={setExtraArgs}

View File

@@ -31,6 +31,7 @@ interface ModelItemProps {
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) => Promise<void>;
onTestModel: (
name: string,
@@ -92,6 +93,11 @@ export default function ModelItem({
const [editAbilities, setEditAbilities] = useState<string[]>(
modelType === 'llm' ? (model as LLMModel).abilities || [] : [],
);
const [editContextLength, setEditContextLength] = useState(
modelType === 'llm' && (model as LLMModel).context_length
? String((model as LLMModel).context_length)
: '',
);
const [editExtraArgs, setEditExtraArgs] = useState<ExtraArg[]>(
convertExtraArgsToArray(model.extra_args),
);
@@ -106,13 +112,27 @@ export default function ModelItem({
setEditAbilities(
modelType === 'llm' ? (model as LLMModel).abilities || [] : [],
);
setEditContextLength(
modelType === 'llm' && (model as LLMModel).context_length
? String((model as LLMModel).context_length)
: '',
);
setEditExtraArgs(convertExtraArgsToArray(model.extra_args));
onResetTestResult();
}
}, [isEditOpen]);
const handleSave = async () => {
await onUpdateModel(editName, editAbilities, editExtraArgs);
const parsedContextLength =
modelType === 'llm' && editContextLength.trim()
? Number(editContextLength.trim())
: null;
await onUpdateModel(
editName,
editAbilities,
editExtraArgs,
parsedContextLength,
);
};
const handleTest = async () => {
@@ -287,6 +307,25 @@ export default function ModelItem({
</div>
)}
{modelType === 'llm' && (
<div className="space-y-2">
<Label htmlFor={`edit-context-length-${model.uuid}`}>
{t('models.contextLength')}
</Label>
<Input
id={`edit-context-length-${model.uuid}`}
type="number"
min={1}
step={1}
inputMode="numeric"
placeholder={t('models.contextLengthPlaceholder')}
value={editContextLength}
disabled={isLangBotModels}
onChange={(e) => setEditContextLength(e.target.value)}
/>
</div>
)}
<ExtraArgsEditor
args={editExtraArgs}
onChange={setEditExtraArgs}

View File

@@ -60,6 +60,7 @@ interface ProviderCardProps {
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) => Promise<void>;
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
onAddScannedModels: (
@@ -74,6 +75,7 @@ interface ProviderCardProps {
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) => Promise<void>;
onOpenDeleteConfirm: (modelId: string) => void;
onCloseDeleteConfirm: () => void;
@@ -405,13 +407,19 @@ export default function ProviderCard({
onOpenDeleteConfirm={onOpenDeleteConfirm}
onCloseDeleteConfirm={onCloseDeleteConfirm}
onDeleteModel={() => onDeleteModel(model.uuid, 'llm')}
onUpdateModel={(name, abilities, extraArgs) =>
onUpdateModel={(
name,
abilities,
extraArgs,
contextLength,
) =>
onUpdateModel(
model.uuid,
'llm',
name,
abilities,
extraArgs,
contextLength,
)
}
onTestModel={(name, abilities, extraArgs) =>

View File

@@ -53,6 +53,7 @@ export interface ModelItemProps {
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) => Promise<void>;
onTest: (
name: string,
@@ -89,6 +90,7 @@ export interface ProviderCardProps {
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) => Promise<void>;
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
onAddScannedModels: (
@@ -103,6 +105,7 @@ export interface ProviderCardProps {
name: string,
abilities: string[],
extraArgs: ExtraArg[],
contextLength?: number | null,
) => Promise<void>;
onOpenDeleteConfirm: (modelId: string) => void;
onCloseDeleteConfirm: () => void;

View File

@@ -96,6 +96,7 @@ export interface LLMModel {
provider_uuid: string;
provider?: ModelProvider;
abilities?: string[];
context_length?: number | null;
extra_args?: object;
}

View File

@@ -201,6 +201,9 @@ const enUS = {
selectModelAbilities: 'Select model abilities',
visionAbility: 'Vision Ability',
functionCallAbility: 'Function Call',
contextLength: 'Context Window',
contextLengthPlaceholder: 'Unknown',
contextLengthInvalid: 'Context window must be a positive integer',
extraParameters: 'Extra Parameters',
addParameter: 'Add Parameter',
keyName: 'Key Name',

View File

@@ -206,6 +206,9 @@ const esES = {
selectModelAbilities: 'Seleccionar capacidades del modelo',
visionAbility: 'Capacidad de visión',
functionCallAbility: 'Llamada a funciones',
contextLength: 'Ventana de contexto',
contextLengthPlaceholder: 'Desconocido',
contextLengthInvalid: 'La ventana de contexto debe ser un entero positivo',
extraParameters: 'Parámetros adicionales',
addParameter: 'Añadir parámetro',
keyName: 'Nombre de la clave',

View File

@@ -204,6 +204,10 @@ const jaJP = {
selectModelAbilities: 'モデル機能を選択',
visionAbility: '視覚機能',
functionCallAbility: '関数呼び出し',
contextLength: 'コンテキストウィンドウ',
contextLengthPlaceholder: '不明',
contextLengthInvalid:
'コンテキストウィンドウは正の整数である必要があります',
extraParameters: '追加パラメータ',
addParameter: 'パラメータを追加',
keyName: 'キー名',

View File

@@ -203,6 +203,10 @@ const ruRU = {
selectModelAbilities: 'Выберите возможности модели',
visionAbility: 'Распознавание изображений',
functionCallAbility: 'Вызов функций',
contextLength: 'Контекстное окно',
contextLengthPlaceholder: 'Неизвестно',
contextLengthInvalid:
'Контекстное окно должно быть положительным целым числом',
extraParameters: 'Дополнительные параметры',
addParameter: 'Добавить параметр',
keyName: 'Имя ключа',

View File

@@ -199,6 +199,9 @@ const thTH = {
selectModelAbilities: 'เลือกความสามารถของโมเดล',
visionAbility: 'ความสามารถด้านภาพ',
functionCallAbility: 'การเรียกฟังก์ชัน',
contextLength: 'หน้าต่างบริบท',
contextLengthPlaceholder: 'ไม่ทราบ',
contextLengthInvalid: 'หน้าต่างบริบทต้องเป็นจำนวนเต็มบวก',
extraParameters: 'พารามิเตอร์เพิ่มเติม',
addParameter: 'เพิ่มพารามิเตอร์',
keyName: 'ชื่อคีย์',

View File

@@ -203,6 +203,9 @@ const viVN = {
selectModelAbilities: 'Chọn khả năng mô hình',
visionAbility: 'Khả năng thị giác',
functionCallAbility: 'Gọi hàm',
contextLength: 'Cửa sổ ngữ cảnh',
contextLengthPlaceholder: 'Không rõ',
contextLengthInvalid: 'Cửa sổ ngữ cảnh phải là số nguyên dương',
extraParameters: 'Tham số bổ sung',
addParameter: 'Thêm tham số',
keyName: 'Tên khóa',

View File

@@ -193,6 +193,9 @@ const zhHans = {
selectModelAbilities: '选择模型能力',
visionAbility: '视觉能力',
functionCallAbility: '函数调用',
contextLength: '上下文窗口',
contextLengthPlaceholder: '未知',
contextLengthInvalid: '上下文窗口必须是正整数',
extraParameters: '额外参数',
addParameter: '添加参数',
keyName: '键名',

View File

@@ -193,6 +193,9 @@ const zhHant = {
selectModelAbilities: '選擇模型能力',
visionAbility: '視覺能力',
functionCallAbility: '函數呼叫',
contextLength: '上下文視窗',
contextLengthPlaceholder: '未知',
contextLengthInvalid: '上下文視窗必須是正整數',
extraParameters: '額外參數',
addParameter: '新增參數',
keyName: '鍵名',