feat(models): add provider model scanning (#2106)

* feat(models): add provider model scanning

* fix: double close button

* feat: update plugin module

* fix(monitoring): WeChat Work feedback recording bugs (#2108)

* fix(monitoring): fix WeChat Work feedback recording bugs

- Fix feedback events silently dropped when stream session expires:
  dispatch feedback handlers regardless of session availability
- Fix IntegrityError on repeated feedback (like→dislike) for same
  message: implement UPSERT logic in record_feedback()
- Fix cancel feedback (type=3) not removing records: add delete logic
- Fix inaccurate_reasons validation error: convert int reason codes
  to strings before creating FeedbackEvent (Pydantic expects List[str])
- Fix feedback timestamps 8 hours off in frontend: use parseUTCTimestamp
  instead of new Date() for UTC timestamp parsing
- Fix StreamSessionManager.cleanup missing _feedback_index cleanup

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(monitoring): apply ruff format to wecom feedback files

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: 6mvp6 <13727783693@163.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* feat: add feat for receive files in wecombot

* fix: ruff error

* fix: always show sidebar plus buttons on touch/mobile devices (#2115)

Agent-Logs-Url: https://github.com/langbot-app/LangBot/sessions/e27a4886-fbad-4a7a-8558-67a387852753

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* fix: SPA fallback for all frontend routes, not just /home/*

After migrating from Next.js to Vite SPA, routes like /auth/space/callback
returned 404 because the static file server only had SPA fallback for /home/*.
Now all non-API routes fall back to index.html for React Router to handle.

* style: ruff format main.py

* feat: add marketplace link when no parser available for file upload

Links to /home/market?category=Parser, same pattern as knowledge engine selector.

* fix: lint error

* fix(user): allow password login and password change for Space accounts with local password set

Previously, Space accounts were unconditionally blocked from password login
and password change based on account_type. Now the check verifies whether
the user actually has a local password set, allowing Space users who have
set a local password to authenticate and change it normally.

* feat: add edition field to telemetry payload

Sends constants.edition (community/saas) with each telemetry event
so Space can distinguish between community and SaaS instances.

* style: ruff format telemetry.py

* fix(dingtalk): use voice recognition text instead of raw audio binary

When DingTalk sends a voice message to the bot, the callback JSON contains
a 'recognition' field with the speech-to-text result (powered by Qwen).

Previously, LangBot only extracted the 'downloadCode' to download the raw
audio binary and passed it as 'file_base64' to LLM APIs, which caused
400 errors since most models don't support this content type.

This patch:
- Extracts the 'recognition' field from DingTalk audio message content
- Uses it as plain text input to the LLM instead of raw audio
- Falls back to audio binary only when no recognition text is available
- Fixes duplicate text issue for audio messages with recognition

Fixes voice messages returning 'Request failed' on all LLM models.

* feat: integrate Alembic for database migrations

Replace manual if-sqlite/if-postgres branching with Alembic:
- Add alembic dependency
- Create programmatic alembic env (no CLI/alembic.ini needed)
- Support async engines via run_sync passthrough
- render_as_batch=True for SQLite ALTER TABLE compatibility
- Auto-stamp baseline on first run (existing DB at version 25)
- Run alembic upgrade head after legacy migrations
- Include sample migration showing schema + data migration patterns
- Add alembic dir to package-data for distribution

* ci: add migration test workflow for SQLite and PostgreSQL

Tests alembic upgrade on both databases:
- Stamp baseline on existing schema
- Upgrade to head
- Idempotent re-upgrade
- Fresh DB upgrade from scratch

* feat: add autogenerate support and CLI entrypoint for alembic

- autogenerate: compare ORM models vs DB schema to generate migrations
- CLI: python -m langbot.pkg.persistence.alembic_runner <command>
  - autogenerate, upgrade, stamp, current
- Reads data/config.yaml for DB connection

* fix: add filereader for dingtalk,lark (#2122)

* fix: add filereader for dingtalk

* feat: add lark

* feat: update uv.lock

* chore: update version to 4.9.6 in pyproject.toml, __init__.py, and uv.lock

* fix: update langbot-plugin version to 0.3.8

* fix: update langbot-plugin version to 0.3.8

* docs: update database migration instructions in AGENTS.md

* fix(dashscopeapi): fix null value check in reasoning content processing logic (#2128)

* fix(n8n-runner): fix output_key not applied when n8n returns plain JSON (#2119)

* fix: bump dependencies to resolve Dependabot security alerts (#2130)

* fix: bump dependencies to resolve Dependabot security alerts

Python:
- aiohttp: >=3.11.18 → >=3.13.4 (duplicate Host headers, header injection, redirect leak, multipart DoS)
- cryptography: >=44.0.3 → >=46.0.7 (buffer overflow with non-contiguous buffers)
- pillow: >=11.2.1 → >=12.2.0 (FITS GZIP decompression bomb, HIGH)
- langchain-text-splitters: >=0.0.1 → >=1.1.2 (SSRF redirect bypass)
- langchain-core: add >=1.2.28 (incomplete f-string validation)
- langsmith: add >=0.7.31 (streaming token redaction bypass)
- python-multipart: add >=0.0.26 (multipart DoS)
- Mako: add >=1.3.11 (path traversal)
- pytest: >=8.4.1 → >=9.0.3 (tmpdir handling)
- uv: >=0.7.11 → >=0.11.6 (arbitrary file deletion)

JavaScript (web/):
- vite: ^8.0.3 → ^8.0.5 (fs.deny bypass, WebSocket file read, path traversal, HIGH)
- axios: ^1.13.5 → ^1.15.0 (cloud metadata exfiltration)
- lodash: ^4.17.23 → ^4.18.0 (code injection via _.template, prototype pollution, HIGH)

* fix: update pnpm-lock.yaml for bumped dependencies

* feat(ci): add i18n key consistency check for frontend locales (#2133)

* feat(ci): add i18n key consistency check workflow

Agent-Logs-Url: https://github.com/langbot-app/LangBot/sessions/c7bf50da-189b-49a5-9671-dbe8e70ff9d0

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* feat(ci): replace eval with line-by-line parser, add permissions block

Agent-Logs-Url: https://github.com/langbot-app/LangBot/sessions/c7bf50da-189b-49a5-9671-dbe8e70ff9d0

Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>

* feat(models): add provider model scanning

* feat(models): add 'select all' functionality and enrich model abilities

* fix:ruff

* fix:ruff

---------

Co-authored-by: WangCham <651122857@qq.com>
Co-authored-by: 6mvp6 <119733319+6mvp6@users.noreply.github.com>
Co-authored-by: 6mvp6 <13727783693@163.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Guanchao Wang <wangcham233@gmail.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com>
Co-authored-by: RockChinQ <rockchinq@gmail.com>
Co-authored-by: haiyangbg <zhouhaiyangaa@gmail.com>
Co-authored-by: Rock Chin <1010553892@qq.com>
Co-authored-by: Amadeus <115918672+AmadeusKurisu1@users.noreply.github.com>
Co-authored-by: hzhhong <hung.z.h916@gmail.com>
Co-authored-by: fdc310 <2213070223@qq.com>
This commit is contained in:
sheetung
2026-04-19 17:47:07 +08:00
committed by GitHub
parent 05c684d757
commit 4e5a6ee79a
23 changed files with 1213 additions and 115 deletions

View File

@@ -43,3 +43,12 @@ class ModelProvidersRouterGroup(group.RouterGroup):
return self.success()
except ValueError as e:
return self.http_status(400, -1, str(e))
@self.route('/<provider_uuid>/scan-models', methods=['GET'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
async def _(provider_uuid: str) -> str:
try:
model_type = quart.request.args.get('type')
result = await self.ap.provider_service.scan_provider_models(provider_uuid, model_type)
return self.success(data=result)
except ValueError as e:
return self.http_status(400, -1, str(e))

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import uuid
import traceback
import sqlalchemy
@@ -164,3 +165,66 @@ class ModelProviderService:
.values(api_keys=[api_key])
)
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')
async def scan_provider_models(self, provider_uuid: str, model_type: str | None = None) -> dict:
provider = await self.get_provider(provider_uuid)
if provider is None:
raise ValueError('provider not found')
runtime_provider = await self.ap.model_mgr.load_provider(provider)
try:
scan_result = await runtime_provider.requester.scan_models(
runtime_provider.token_mgr.get_token() if runtime_provider.token_mgr.tokens else None
)
except NotImplementedError:
raise ValueError('current provider does not support model scanning')
except Exception as exc:
self.ap.logger.warning(
f'Failed to scan models for provider {provider_uuid}: {exc}\n{traceback.format_exc()}'
)
raise ValueError(str(exc)) from exc
if isinstance(scan_result, dict):
scanned_models = scan_result.get('models', [])
debug_info = scan_result.get('debug')
else:
scanned_models = scan_result
debug_info = None
llm_models = await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid)
embedding_models = await self.ap.embedding_models_service.get_embedding_models_by_provider(provider_uuid)
existing_llm_names = {model['name'] for model in llm_models}
existing_embedding_names = {model['name'] for model in embedding_models}
filtered_models = []
for model in scanned_models:
scanned_type = model.get('type', 'llm')
if model_type and scanned_type != model_type:
continue
model_name = model.get('name') or model.get('id')
if not model_name:
continue
filtered_models.append(
{
'id': model.get('id', model_name),
'name': model_name,
'type': scanned_type,
'abilities': model.get('abilities', []),
'display_name': model.get('display_name'),
'description': model.get('description'),
'context_length': model.get('context_length'),
'owned_by': model.get('owned_by'),
'input_modalities': model.get('input_modalities', []),
'output_modalities': model.get('output_modalities', []),
'already_added': (
model_name in existing_embedding_names
if scanned_type == 'embedding'
else model_name in existing_llm_names
),
}
)
return {'models': filtered_models, 'debug': debug_info}

View File

@@ -227,7 +227,8 @@ class ModelManager:
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
requester_inst = self.requester_dict[provider_entity.requester](
ap=self.ap, config={'base_url': provider_entity.base_url}
ap=self.ap,
config={'base_url': provider_entity.base_url},
)
await requester_inst.initialize()

View File

@@ -303,6 +303,14 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
async def initialize(self):
pass
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any] | list[dict[str, typing.Any]]:
"""Scan models supported by the provider.
The default implementation does not support scanning. Requesters that
can enumerate remote models should override this method.
"""
raise NotImplementedError('This provider does not support model scanning')
@abc.abstractmethod
async def invoke_llm(
self,

View File

@@ -31,6 +31,192 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
)
def _mask_api_key(self, api_key: str | None) -> str:
if not api_key:
return ''
if len(api_key) <= 8:
return '****'
return f'{api_key[:4]}...{api_key[-4:]}'
def _infer_model_type(self, model_id: str) -> str:
normalized_model_id = (model_id or '').lower()
embedding_keywords = (
'embedding',
'embed',
'bge-',
'e5-',
'm3e',
'gte-',
'multilingual-e5',
'text-embedding',
)
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
normalized_model_id = (model_id or '').lower()
abilities: set[str] = set()
def _flatten(value: typing.Any) -> list[str]:
if value is None:
return []
if isinstance(value, str):
return [value.lower()]
if isinstance(value, dict):
flattened: list[str] = []
for nested_value in value.values():
flattened.extend(_flatten(nested_value))
return flattened
if isinstance(value, (list, tuple, set)):
flattened: list[str] = []
for nested_value in value:
flattened.extend(_flatten(nested_value))
return flattened
return [str(value).lower()]
capability_tokens = _flatten(item.get('capabilities'))
capability_tokens.extend(_flatten(item.get('modalities')))
capability_tokens.extend(_flatten(item.get('input_modalities')))
capability_tokens.extend(_flatten(item.get('output_modalities')))
capability_tokens.extend(_flatten(item.get('supported_generation_methods')))
capability_tokens.extend(_flatten(item.get('supported_parameters')))
capability_tokens.extend(_flatten(item.get('architecture')))
combined_tokens = capability_tokens + [normalized_model_id]
vision_keywords = (
'vision',
'image',
'file',
'video',
'multimodal',
'vl',
'ocr',
'omni',
)
function_call_keywords = (
'function',
'tool',
'tools',
'tool_choice',
'tool_call',
'tool-use',
'tool_use',
)
if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens):
abilities.add('vision')
if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens):
abilities.add('func_call')
return sorted(abilities)
def _normalize_modalities(self, value: typing.Any) -> list[str]:
normalized: list[str] = []
def _collect(item: typing.Any):
if item is None:
return
if isinstance(item, str):
for part in item.replace('->', ',').replace('+', ',').split(','):
token = part.strip().lower()
if token and token not in normalized:
normalized.append(token)
return
if isinstance(item, dict):
for nested in item.values():
_collect(nested)
return
if isinstance(item, (list, tuple, set)):
for nested in item:
_collect(nested)
return
_collect(value)
return normalized
def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]:
display_name = item.get('name')
if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id:
display_name = ''
description = item.get('description')
if not isinstance(description, str) or not description.strip():
description = ''
context_length = item.get('context_length')
if context_length is None and isinstance(item.get('top_provider'), dict):
context_length = item['top_provider'].get('context_length')
if not isinstance(context_length, int):
try:
context_length = int(context_length) if context_length is not None else None
except (TypeError, ValueError):
context_length = None
input_modalities = self._normalize_modalities(item.get('input_modalities'))
output_modalities = self._normalize_modalities(item.get('output_modalities'))
if isinstance(item.get('architecture'), dict):
if not input_modalities:
input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities'))
if not output_modalities:
output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities'))
owned_by = item.get('owned_by')
if not isinstance(owned_by, str) or not owned_by.strip():
owned_by = ''
return {
'display_name': display_name or None,
'description': description or None,
'context_length': context_length,
'owned_by': owned_by or None,
'input_modalities': input_modalities,
'output_modalities': output_modalities,
}
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
headers = {}
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models'
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
response = await client.get(models_url, headers=headers)
response.raise_for_status()
payload = response.json()
models = []
for item in payload.get('data', []):
model_id = item.get('id')
if not model_id:
continue
models.append(
{
'id': model_id,
'name': model_id,
'type': self._infer_model_type(model_id),
'abilities': self._infer_model_abilities(item, model_id),
**self._extract_scan_metadata(item, model_id),
}
)
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
return {
'models': models,
'debug': {
'request': {
'method': 'GET',
'url': models_url,
'headers': {
'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '',
},
},
'response': payload,
},
}
async def _req(
self,
args: dict,

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import typing
import httpx
from . import chatcmpl
@@ -20,6 +21,68 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions):
'timeout': 120,
}
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
models_url = 'https://generativelanguage.googleapis.com/v1beta/models'
params = {'key': api_key} if api_key else {}
all_models: list[dict[str, typing.Any]] = []
next_page_token = ''
last_payload: dict[str, typing.Any] = {}
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
while True:
request_params = dict(params)
if next_page_token:
request_params['pageToken'] = next_page_token
response = await client.get(models_url, params=request_params)
response.raise_for_status()
payload = response.json()
last_payload = payload
for item in payload.get('models', []):
model_name = item.get('name', '')
model_id = model_name.replace('models/', '', 1)
if not model_id:
continue
supported_methods = item.get('supportedGenerationMethods', []) or []
if 'embedContent' in supported_methods and 'generateContent' not in supported_methods:
model_type = 'embedding'
else:
model_type = 'llm'
all_models.append(
{
'id': model_id,
'name': model_id,
'type': model_type,
'abilities': self._infer_model_abilities(item, model_id),
'display_name': item.get('displayName') or None,
'description': item.get('description') or None,
'context_length': item.get('inputTokenLimit'),
'input_modalities': self._normalize_modalities(item.get('inputModalities')),
'output_modalities': self._normalize_modalities(item.get('outputModalities')),
}
)
next_page_token = payload.get('nextPageToken', '')
if not next_page_token:
break
all_models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
return {
'models': all_models,
'debug': {
'request': {
'method': 'GET',
'url': models_url,
'query': {'key': self._mask_api_key(api_key)} if api_key else {},
},
'response': last_payload,
},
}
async def _closure_stream(
self,
query: pipeline_query.Query,

View File

@@ -31,6 +31,175 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
)
def _mask_api_key(self, api_key: str | None) -> str:
if not api_key:
return ''
if len(api_key) <= 8:
return '****'
return f'{api_key[:4]}...{api_key[-4:]}'
def _infer_model_type(self, model_id: str) -> str:
normalized_model_id = (model_id or '').lower()
embedding_keywords = (
'embedding',
'embed',
'bge-',
'e5-',
'm3e',
'gte-',
'multilingual-e5',
'text-embedding',
)
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
normalized_model_id = (model_id or '').lower()
abilities: set[str] = set()
def _flatten(value: typing.Any) -> list[str]:
if value is None:
return []
if isinstance(value, str):
return [value.lower()]
if isinstance(value, dict):
flattened: list[str] = []
for nested_value in value.values():
flattened.extend(_flatten(nested_value))
return flattened
if isinstance(value, (list, tuple, set)):
flattened: list[str] = []
for nested_value in value:
flattened.extend(_flatten(nested_value))
return flattened
return [str(value).lower()]
capability_tokens = _flatten(item.get('capabilities'))
capability_tokens.extend(_flatten(item.get('modalities')))
capability_tokens.extend(_flatten(item.get('input_modalities')))
capability_tokens.extend(_flatten(item.get('output_modalities')))
capability_tokens.extend(_flatten(item.get('supported_generation_methods')))
capability_tokens.extend(_flatten(item.get('supported_parameters')))
capability_tokens.extend(_flatten(item.get('architecture')))
combined_tokens = capability_tokens + [normalized_model_id]
vision_keywords = ('vision', 'image', 'file', 'video', 'multimodal', 'vl', 'ocr', 'omni')
function_call_keywords = ('function', 'tool', 'tools', 'tool_choice', 'tool_call', 'tool-use', 'tool_use')
if any(any(keyword in token for keyword in vision_keywords) for token in combined_tokens):
abilities.add('vision')
if any(any(keyword in token for keyword in function_call_keywords) for token in combined_tokens):
abilities.add('func_call')
return sorted(abilities)
def _normalize_modalities(self, value: typing.Any) -> list[str]:
normalized: list[str] = []
def _collect(item: typing.Any):
if item is None:
return
if isinstance(item, str):
for part in item.replace('->', ',').replace('+', ',').split(','):
token = part.strip().lower()
if token and token not in normalized:
normalized.append(token)
return
if isinstance(item, dict):
for nested in item.values():
_collect(nested)
return
if isinstance(item, (list, tuple, set)):
for nested in item:
_collect(nested)
return
_collect(value)
return normalized
def _extract_scan_metadata(self, item: dict[str, typing.Any], model_id: str) -> dict[str, typing.Any]:
display_name = item.get('name')
if not isinstance(display_name, str) or not display_name.strip() or display_name == model_id:
display_name = ''
description = item.get('description')
if not isinstance(description, str) or not description.strip():
description = ''
context_length = item.get('context_length')
if context_length is None and isinstance(item.get('top_provider'), dict):
context_length = item['top_provider'].get('context_length')
if not isinstance(context_length, int):
try:
context_length = int(context_length) if context_length is not None else None
except (TypeError, ValueError):
context_length = None
input_modalities = self._normalize_modalities(item.get('input_modalities'))
output_modalities = self._normalize_modalities(item.get('output_modalities'))
if isinstance(item.get('architecture'), dict):
if not input_modalities:
input_modalities = self._normalize_modalities(item['architecture'].get('input_modalities'))
if not output_modalities:
output_modalities = self._normalize_modalities(item['architecture'].get('output_modalities'))
owned_by = item.get('owned_by')
if not isinstance(owned_by, str) or not owned_by.strip():
owned_by = ''
return {
'display_name': display_name or None,
'description': description or None,
'context_length': context_length,
'owned_by': owned_by or None,
'input_modalities': input_modalities,
'output_modalities': output_modalities,
}
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
headers = {}
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/models'
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
response = await client.get(models_url, headers=headers)
response.raise_for_status()
payload = response.json()
models = []
for item in payload.get('data', []):
model_id = item.get('id')
if not model_id:
continue
models.append(
{
'id': model_id,
'name': model_id,
'type': self._infer_model_type(model_id),
'abilities': self._infer_model_abilities(item, model_id),
**self._extract_scan_metadata(item, model_id),
}
)
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
return {
'models': models,
'debug': {
'request': {
'method': 'GET',
'url': models_url,
'headers': {
'Authorization': f'Bearer {self._mask_api_key(api_key)}' if api_key else '',
},
},
'response': payload,
},
}
async def _req(
self,
query: pipeline_query.Query,

View File

@@ -8,6 +8,7 @@ import uuid
import json
import ollama
import httpx
from .. import errors, requester
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
@@ -31,6 +32,60 @@ class OllamaChatCompletions(requester.ProviderAPIRequester):
os.environ['OLLAMA_HOST'] = self.requester_cfg['base_url']
self.client = ollama.AsyncClient(timeout=self.requester_cfg['timeout'])
def _infer_model_type(self, model_id: str) -> str:
normalized_model_id = (model_id or '').lower()
embedding_keywords = ('embedding', 'embed', 'bge-', 'e5-', 'm3e', 'gte-', 'text-embedding')
return 'embedding' if any(keyword in normalized_model_id for keyword in embedding_keywords) else 'llm'
def _infer_model_abilities(self, item: dict[str, typing.Any], model_id: str) -> list[str]:
normalized_model_id = (model_id or '').lower()
abilities: set[str] = set()
details = item.get('details', {}) or {}
families = details.get('families', []) or []
tokens = [normalized_model_id, str(details.get('family', '')).lower()]
tokens.extend(str(family).lower() for family in families)
if any(keyword in token for token in tokens for keyword in ('vision', 'vl', 'omni', 'llava', 'ocr')):
abilities.add('vision')
if any(keyword in token for token in tokens for keyword in ('tool', 'function')):
abilities.add('func_call')
return sorted(abilities)
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
del api_key
models_url = f'{self.requester_cfg["base_url"].rstrip("/")}/api/tags'
async with httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']) as client:
response = await client.get(models_url)
response.raise_for_status()
payload = response.json()
models: list[dict[str, typing.Any]] = []
for item in payload.get('models', []):
model_id = item.get('model') or item.get('name')
if not model_id:
continue
models.append(
{
'id': model_id,
'name': item.get('name', model_id),
'type': self._infer_model_type(model_id),
'abilities': self._infer_model_abilities(item, model_id),
}
)
models.sort(key=lambda item: (item['type'] != 'llm', item['name'].lower()))
return {
'models': models,
'debug': {
'request': {
'method': 'GET',
'url': models_url,
},
'response': payload,
},
}
async def _req(
self,
args: dict,

View File

@@ -15,3 +15,11 @@ class OpenRouterChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions):
'base_url': 'https://openrouter.ai/api/v1',
'timeout': 120,
}
async def scan_models(self, api_key: str | None = None) -> dict[str, typing.Any]:
original_base_url = self.requester_cfg.get('base_url', '')
self.requester_cfg['base_url'] = 'https://openrouter.ai/api/v1'
try:
return await super().scan_models(api_key)
finally:
self.requester_cfg['base_url'] = original_base_url