mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-04 04:54:36 +00:00
269 lines
11 KiB
Python
269 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
import traceback
|
|
|
|
import sqlalchemy
|
|
|
|
from ....core import app
|
|
from ....entity.persistence import model as persistence_model
|
|
|
|
|
|
class ModelProviderService:
    """Service for managing model providers.

    Reads and writes the ``ModelProvider`` persistence table and keeps the
    runtime provider registry on ``self.ap.model_mgr`` in sync with it
    (providers are loaded on create, reloaded on update, removed on delete).
    """

    # UUID of the built-in Space model provider whose keys are updated by
    # update_space_model_provider_api_keys().
    _SPACE_PROVIDER_UUID = '00000000-0000-0000-0000-000000000000'

    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    @staticmethod
    def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
        """Return *api_keys* as a de-duplicated list of stripped, non-empty strings.

        A single string is treated as one key. Non-string and blank entries are
        dropped, and the first-seen order of the remaining keys is preserved.
        """
        if api_keys is None:
            return []

        raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
        stripped = (key.strip() for key in raw_keys if isinstance(key, str))
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(key for key in stripped if key))

    @staticmethod
    def _decode_api_keys(raw_api_keys: object) -> object:
        """Decode a JSON-encoded ``api_keys`` value.

        Strings are parsed as JSON. A string that is not valid JSON, or that does
        not decode to a list, becomes ``[]``. Non-string values are returned
        unchanged.
        """
        if not isinstance(raw_api_keys, str):
            return raw_api_keys

        import json

        try:
            decoded = json.loads(raw_api_keys)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return []
        return decoded if isinstance(decoded, list) else []

    @staticmethod
    def _dependent_model_tables() -> tuple[tuple[str, type], ...]:
        """Return ``(label, table)`` pairs for every model table that references a provider."""
        return (
            ('LLM', persistence_model.LLMModel),
            ('Embedding', persistence_model.EmbeddingModel),
            ('Rerank', persistence_model.RerankModel),
        )

    def _serialize_provider(self, provider_row) -> dict:
        """Serialize a ``ModelProvider`` row into a dict, decoding ``api_keys`` if needed."""
        provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider_row)
        # api_keys may be returned as a JSON-encoded string; decode it so callers get a list.
        if isinstance(provider_dict.get('api_keys'), str):
            provider_dict['api_keys'] = self._decode_api_keys(provider_dict['api_keys'])
        return provider_dict

    async def get_providers(self) -> list[dict]:
        """Get all providers as serialized dicts."""
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
        return [self._serialize_provider(row) for row in result.all()]

    async def get_provider(self, provider_uuid: str) -> dict | None:
        """Get a single provider by UUID, or ``None`` if it does not exist."""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.ModelProvider).where(
                persistence_model.ModelProvider.uuid == provider_uuid
            )
        )
        provider = result.first()
        if provider is None:
            return None
        return self._serialize_provider(provider)

    async def create_provider(self, provider_data: dict) -> str:
        """Create a new provider, persist it, and register it at runtime.

        Args:
            provider_data: Column values for the new provider. A fresh UUID is
                always generated, and ``api_keys`` is normalized. The caller's
                dict is not modified.

        Returns:
            The UUID of the created provider.
        """
        provider_data = {
            **provider_data,
            'uuid': str(uuid.uuid4()),
            'api_keys': self._normalize_api_keys(provider_data.get('api_keys')),
        }

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
        )

        # load to runtime
        runtime_provider = await self.ap.model_mgr.load_provider(provider_data)
        self.ap.model_mgr.provider_dict[runtime_provider.provider_entity.uuid] = runtime_provider

        return provider_data['uuid']

    async def update_provider(self, provider_uuid: str, provider_data: dict) -> None:
        """Update an existing provider and reload it at runtime.

        Any ``uuid`` key in *provider_data* is ignored (the primary key is never
        rewritten). ``api_keys``, if present, is normalized. The caller's dict is
        not modified.
        """
        provider_data = {key: value for key, value in provider_data.items() if key != 'uuid'}
        if 'api_keys' in provider_data:
            provider_data['api_keys'] = self._normalize_api_keys(provider_data['api_keys'])

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_model.ModelProvider)
            .where(persistence_model.ModelProvider.uuid == provider_uuid)
            .values(**provider_data)
        )
        await self.ap.model_mgr.reload_provider(provider_uuid)

    async def delete_provider(self, provider_uuid: str) -> None:
        """Delete a provider, but only if no models reference it.

        Raises:
            ValueError: if any LLM, embedding, or rerank model still uses the provider.
        """
        for label, model_table in self._dependent_model_tables():
            result = await self.ap.persistence_mgr.execute_async(
                sqlalchemy.select(model_table).where(model_table.provider_uuid == provider_uuid).limit(1)
            )
            if result.first() is not None:
                raise ValueError(f'Cannot delete provider: {label} models still reference it')

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_model.ModelProvider).where(
                persistence_model.ModelProvider.uuid == provider_uuid
            )
        )

        await self.ap.model_mgr.remove_provider(provider_uuid)

    async def get_provider_model_counts(self, provider_uuid: str) -> dict:
        """Count the models of each type that use this provider.

        Returns:
            ``{'llm_count': int, 'embedding_count': int, 'rerank_count': int}``
        """
        counts = {}
        for label, model_table in self._dependent_model_tables():
            result = await self.ap.persistence_mgr.execute_async(
                sqlalchemy.select(sqlalchemy.func.count())
                .select_from(model_table)
                .where(model_table.provider_uuid == provider_uuid)
            )
            counts[f'{label.lower()}_count'] = result.scalar() or 0
        return counts

    async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
        """Return the UUID of a provider with the same config, creating one if needed.

        A provider matches when its requester and base_url are equal and its
        normalized key set equals the normalized *api_keys*. Key order and
        duplicates are ignored.
        """
        api_keys = self._normalize_api_keys(api_keys)
        wanted_keys = set(api_keys)

        # Try to find existing provider with same config
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.ModelProvider).where(
                persistence_model.ModelProvider.requester == requester,
                persistence_model.ModelProvider.base_url == base_url,
            )
        )
        for provider in result.all():
            # Stored keys may be a JSON string (see _serialize_provider); decode and
            # normalize them so they compare like-for-like with the requested keys.
            stored_keys = self._normalize_api_keys(self._decode_api_keys(provider.api_keys))
            if set(stored_keys) == wanted_keys:
                return provider.uuid

        # Create new provider, named after the URL host when one can be parsed.
        provider_name = requester
        if isinstance(base_url, str) and base_url:
            from urllib.parse import urlparse

            try:
                provider_name = urlparse(base_url).netloc or requester
            except ValueError:
                # urlparse rejects some malformed URLs (e.g. bad IPv6 brackets);
                # fall back to the requester name.
                pass

        return await self.create_provider(
            {
                'name': provider_name,
                'requester': requester,
                'base_url': base_url,
                'api_keys': api_keys,
            }
        )

    async def update_space_model_provider_api_keys(self, api_key: str) -> None:
        """Replace the Space model provider's API keys with *api_key* and reload it."""
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_model.ModelProvider)
            .where(persistence_model.ModelProvider.uuid == self._SPACE_PROVIDER_UUID)
            .values(api_keys=self._normalize_api_keys(api_key))
        )
        await self.ap.model_mgr.reload_provider(self._SPACE_PROVIDER_UUID)

    async def scan_provider_models(self, provider_uuid: str, model_type: str | None = None) -> dict:
        """Ask the provider's requester which models it offers.

        Args:
            provider_uuid: UUID of the provider to scan.
            model_type: If given, only scanned models of this type are returned.

        Returns:
            ``{'models': [...], 'debug': ...}``. Each entry has ``already_added``
            set when a model with the same name is already registered for this
            provider.

        Raises:
            ValueError: if the provider does not exist, does not support
                scanning, or the scan itself fails.
        """
        provider = await self.get_provider(provider_uuid)
        if provider is None:
            raise ValueError('provider not found')

        runtime_provider = await self.ap.model_mgr.load_provider(provider)

        try:
            scan_result = await runtime_provider.requester.scan_models(
                runtime_provider.token_mgr.get_token() if runtime_provider.token_mgr.tokens else None
            )
        except NotImplementedError:
            raise ValueError('current provider does not support model scanning') from None
        except Exception as exc:
            self.ap.logger.warning(
                f'Failed to scan models for provider {provider_uuid}: {exc}\n{traceback.format_exc()}'
            )
            raise ValueError(str(exc)) from exc

        # Requesters may return either {'models': [...], 'debug': ...} or a bare list.
        if isinstance(scan_result, dict):
            scanned_models = scan_result.get('models') or []
            debug_info = scan_result.get('debug')
        else:
            scanned_models = scan_result or []
            debug_info = None

        llm_models = await self.ap.llm_model_service.get_llm_models_by_provider(provider_uuid)
        embedding_models = await self.ap.embedding_models_service.get_embedding_models_by_provider(provider_uuid)
        existing_llm_names = {model['name'] for model in llm_models}
        existing_embedding_names = {model['name'] for model in embedding_models}

        filtered_models = []
        for model in scanned_models:
            if not isinstance(model, dict):
                continue

            scanned_type = model.get('type') or 'llm'
            if model_type and scanned_type != model_type:
                continue

            model_name = model.get('name') or model.get('id')
            if not model_name:
                continue

            # NOTE(review): every non-embedding type (including any rerank entries)
            # is checked against LLM names only — confirm whether rerank models
            # should be checked against their own table.
            existing_names = existing_embedding_names if scanned_type == 'embedding' else existing_llm_names

            filtered_models.append(
                {
                    'id': model.get('id') or model_name,
                    'name': model_name,
                    'type': scanned_type,
                    'abilities': model.get('abilities') or [],
                    'display_name': model.get('display_name'),
                    'description': model.get('description'),
                    'context_length': model.get('context_length'),
                    'owned_by': model.get('owned_by'),
                    'input_modalities': model.get('input_modalities') or [],
                    'output_modalities': model.get('output_modalities') or [],
                    'already_added': model_name in existing_names,
                }
            )

        return {'models': filtered_models, 'debug': debug_info}
|