mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-05 05:16:03 +00:00
- Replace 23+ individual requester implementations with unified litellmchat.py - Add litellm_provider field to 27 YAML manifests for provider routing - Delete redundant requester subclasses - Add unit tests for LiteLLMRequester (29 tests) - Fix num_retries parameter name (was max_retries) - Fix exception handling order for subclass exceptions LiteLLM provides unified API for 100+ providers, eliminating need for provider-specific requesters.
540 lines
22 KiB
Python
540 lines
22 KiB
Python
from __future__ import annotations
|
|
|
|
import sqlalchemy
|
|
import traceback
|
|
|
|
from . import requester
|
|
from .requesters import litellmchat
|
|
from ...core import app
|
|
from ...discover import engine
|
|
from . import token
|
|
from ...entity.persistence import model as persistence_model
|
|
from ...entity.errors import provider as provider_errors
|
|
|
|
|
|
class ModelManager:
|
|
"""Model manager"""
|
|
|
|
ap: app.Application
|
|
|
|
provider_dict: dict[str, requester.RuntimeProvider]
|
|
"""运行时模型提供商字典, uuid -> RuntimeProvider"""
|
|
|
|
llm_models: list[requester.RuntimeLLMModel]
|
|
|
|
embedding_models: list[requester.RuntimeEmbeddingModel]
|
|
|
|
rerank_models: list[requester.RuntimeRerankModel]
|
|
|
|
requester_components: list[engine.Component]
|
|
|
|
requester_dict: dict[str, type[requester.ProviderAPIRequester]]
|
|
|
|
def __init__(self, ap: app.Application):
|
|
self.ap = ap
|
|
self.llm_models = []
|
|
self.embedding_models = []
|
|
self.rerank_models = []
|
|
self.requester_components = []
|
|
self.requester_dict = {}
|
|
|
|
async def initialize(self):
|
|
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
|
|
|
|
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
|
|
for component in self.requester_components:
|
|
# Skip components that use litellm_provider (they will use litellmchat.py instead)
|
|
if component.spec.get('litellm_provider'):
|
|
self.ap.logger.debug(
|
|
f'Skipping Python class loading for {component.metadata.name} '
|
|
f'(uses litellm_provider={component.spec.get("litellm_provider")})'
|
|
)
|
|
continue
|
|
requester_dict[component.metadata.name] = component.get_python_component_class()
|
|
|
|
self.requester_dict = requester_dict
|
|
|
|
await self.load_models_from_db()
|
|
|
|
# Check if space models service is disabled
|
|
space_config = self.ap.instance_config.data.get('space', {})
|
|
if space_config.get('disable_models_service', False):
|
|
self.ap.logger.info('LangBot Space Models service is disabled, skipping sync.')
|
|
return
|
|
|
|
try:
|
|
await self.sync_new_models_from_space()
|
|
except Exception as e:
|
|
self.ap.logger.warning('Failed to sync new models from LangBot Space, model list may not be updated.')
|
|
self.ap.logger.warning(f' - Error: {e}')
|
|
|
|
async def load_models_from_db(self):
|
|
"""Load models from database"""
|
|
self.ap.logger.info('Loading models from db...')
|
|
|
|
self.llm_models = []
|
|
self.embedding_models = []
|
|
self.rerank_models = []
|
|
self.provider_dict = {}
|
|
providers_result = await self.ap.persistence_mgr.execute_async(
|
|
sqlalchemy.select(persistence_model.ModelProvider)
|
|
)
|
|
for provider in providers_result.all():
|
|
try:
|
|
runtime_provider = await self.load_provider(provider)
|
|
self.provider_dict[provider.uuid] = runtime_provider
|
|
except provider_errors.RequesterNotFoundError as e:
|
|
self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping provider {provider.uuid}')
|
|
continue
|
|
except Exception as e:
|
|
self.ap.logger.error(f'Failed to load provider {provider.uuid}: {e}\n{traceback.format_exc()}')
|
|
|
|
# Load LLM models
|
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
|
llm_models = result.all()
|
|
for llm_model in llm_models:
|
|
try:
|
|
provider = self.provider_dict.get(llm_model.provider_uuid)
|
|
if provider is None:
|
|
self.ap.logger.warning(f'Provider {llm_model.provider_uuid} not found for model {llm_model.uuid}')
|
|
continue
|
|
runtime_llm_model = await self.load_llm_model_with_provider(llm_model, provider)
|
|
self.llm_models.append(runtime_llm_model)
|
|
except Exception as e:
|
|
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
|
|
|
|
# Load embedding models
|
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
|
|
embedding_models = result.all()
|
|
for embedding_model in embedding_models:
|
|
try:
|
|
provider = self.provider_dict.get(embedding_model.provider_uuid)
|
|
if provider is None:
|
|
self.ap.logger.warning(
|
|
f'Provider {embedding_model.provider_uuid} not found for model {embedding_model.uuid}'
|
|
)
|
|
continue
|
|
runtime_embedding_model = await self.load_embedding_model_with_provider(embedding_model, provider)
|
|
self.embedding_models.append(runtime_embedding_model)
|
|
except Exception as e:
|
|
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
|
|
|
|
# Load rerank models
|
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
|
|
rerank_models = result.all()
|
|
for rerank_model in rerank_models:
|
|
try:
|
|
provider = self.provider_dict.get(rerank_model.provider_uuid)
|
|
if provider is None:
|
|
self.ap.logger.warning(
|
|
f'Provider {rerank_model.provider_uuid} not found for model {rerank_model.uuid}'
|
|
)
|
|
continue
|
|
runtime_rerank_model = await self.load_rerank_model_with_provider(rerank_model, provider)
|
|
self.rerank_models.append(runtime_rerank_model)
|
|
except Exception as e:
|
|
self.ap.logger.error(f'Failed to load model {rerank_model.uuid}: {e}\n{traceback.format_exc()}')
|
|
|
|
async def sync_new_models_from_space(self):
|
|
"""Sync models from Space"""
|
|
space_model_provider = await self.ap.persistence_mgr.execute_async(
|
|
sqlalchemy.select(persistence_model.ModelProvider).where(
|
|
persistence_model.ModelProvider.requester == 'space-chat-completions'
|
|
)
|
|
)
|
|
result = space_model_provider.first()
|
|
if result is None:
|
|
raise provider_errors.ProviderNotFoundError('LangBot Models')
|
|
|
|
space_model_provider = result
|
|
|
|
# get the latest models from space
|
|
space_models = await self.ap.space_service.get_models()
|
|
|
|
# Index existing models by uuid. Space reuses a model's uuid across
|
|
# renames / re-specs (e.g. the uuid that used to be ``claude-opus-4-6``
|
|
# may later become ``claude-opus-4-7``). So for Space-managed models we
|
|
# upsert: create when the uuid is new, otherwise update name/abilities/
|
|
# ranking to track Space. Models owned by other providers are never
|
|
# touched, even on an (unexpected) uuid collision.
|
|
existing_llm_models = {m['uuid']: m for m in await self.ap.llm_model_service.get_llm_models()}
|
|
existing_embedding_models = {
|
|
m['uuid']: m for m in await self.ap.embedding_models_service.get_embedding_models()
|
|
}
|
|
|
|
created = 0
|
|
updated = 0
|
|
|
|
for space_model in space_models:
|
|
if space_model.category == 'chat':
|
|
existing = existing_llm_models.get(space_model.uuid)
|
|
if existing is None:
|
|
# model will be automatically loaded
|
|
await self.ap.llm_model_service.create_llm_model(
|
|
{
|
|
'uuid': space_model.uuid,
|
|
'name': space_model.model_id,
|
|
'provider_uuid': space_model_provider.uuid,
|
|
'abilities': space_model.llm_abilities or [],
|
|
'extra_args': {},
|
|
'prefered_ranking': space_model.featured_order,
|
|
},
|
|
preserve_uuid=True,
|
|
auto_set_to_default_pipeline=False,
|
|
)
|
|
created += 1
|
|
elif existing.get('provider_uuid') == space_model_provider.uuid:
|
|
desired = {
|
|
'name': space_model.model_id,
|
|
'provider_uuid': space_model_provider.uuid,
|
|
'abilities': space_model.llm_abilities or [],
|
|
'prefered_ranking': space_model.featured_order,
|
|
}
|
|
if (
|
|
existing.get('name') != desired['name']
|
|
or list(existing.get('abilities') or []) != list(desired['abilities'])
|
|
or existing.get('prefered_ranking') != desired['prefered_ranking']
|
|
):
|
|
await self.ap.llm_model_service.update_llm_model(space_model.uuid, dict(desired))
|
|
updated += 1
|
|
|
|
elif space_model.category == 'embedding':
|
|
existing = existing_embedding_models.get(space_model.uuid)
|
|
if existing is None:
|
|
# model will be automatically loaded
|
|
await self.ap.embedding_models_service.create_embedding_model(
|
|
{
|
|
'uuid': space_model.uuid,
|
|
'name': space_model.model_id,
|
|
'provider_uuid': space_model_provider.uuid,
|
|
'extra_args': {},
|
|
'prefered_ranking': space_model.featured_order,
|
|
},
|
|
preserve_uuid=True,
|
|
)
|
|
created += 1
|
|
elif existing.get('provider_uuid') == space_model_provider.uuid:
|
|
desired = {
|
|
'name': space_model.model_id,
|
|
'provider_uuid': space_model_provider.uuid,
|
|
'prefered_ranking': space_model.featured_order,
|
|
}
|
|
if (
|
|
existing.get('name') != desired['name']
|
|
or existing.get('prefered_ranking') != desired['prefered_ranking']
|
|
):
|
|
await self.ap.embedding_models_service.update_embedding_model(space_model.uuid, dict(desired))
|
|
updated += 1
|
|
|
|
if created or updated:
|
|
self.ap.logger.info(f'Synced models from LangBot Space: {created} added, {updated} updated.')
|
|
|
|
async def init_temporary_runtime_llm_model(
|
|
self,
|
|
model_info: dict,
|
|
) -> requester.RuntimeLLMModel:
|
|
"""Initialize runtime LLM model from dict (for testing)"""
|
|
provider_info = model_info.get('provider', {})
|
|
|
|
runtime_provider = await self.load_provider(provider_info)
|
|
|
|
runtime_llm_model = requester.RuntimeLLMModel(
|
|
model_entity=persistence_model.LLMModel(
|
|
uuid=model_info.get('uuid', ''),
|
|
name=model_info.get('name', ''),
|
|
provider_uuid='',
|
|
abilities=model_info.get('abilities', []),
|
|
extra_args=model_info.get('extra_args', {}),
|
|
),
|
|
provider=runtime_provider,
|
|
)
|
|
|
|
return runtime_llm_model
|
|
|
|
async def init_temporary_runtime_embedding_model(
|
|
self,
|
|
model_info: dict,
|
|
) -> requester.RuntimeEmbeddingModel:
|
|
"""Initialize runtime embedding model from dict (for testing)"""
|
|
provider_info = model_info.get('provider', {})
|
|
runtime_provider = await self.load_provider(provider_info)
|
|
|
|
runtime_embedding_model = requester.RuntimeEmbeddingModel(
|
|
model_entity=persistence_model.EmbeddingModel(
|
|
uuid=model_info.get('uuid', ''),
|
|
name=model_info.get('name', ''),
|
|
provider_uuid='',
|
|
extra_args=model_info.get('extra_args', {}),
|
|
),
|
|
provider=runtime_provider,
|
|
)
|
|
|
|
return runtime_embedding_model
|
|
|
|
async def init_temporary_runtime_rerank_model(
|
|
self,
|
|
model_info: dict,
|
|
) -> requester.RuntimeRerankModel:
|
|
"""Initialize runtime rerank model from dict (for testing)"""
|
|
provider_info = model_info.get('provider', {})
|
|
runtime_provider = await self.load_provider(provider_info)
|
|
|
|
runtime_rerank_model = requester.RuntimeRerankModel(
|
|
model_entity=persistence_model.RerankModel(
|
|
uuid=model_info.get('uuid', ''),
|
|
name=model_info.get('name', ''),
|
|
provider_uuid='',
|
|
extra_args=model_info.get('extra_args', {}),
|
|
),
|
|
provider=runtime_provider,
|
|
)
|
|
|
|
return runtime_rerank_model
|
|
|
|
async def load_provider(
|
|
self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict
|
|
) -> requester.RuntimeProvider:
|
|
"""Load provider from dict"""
|
|
if isinstance(provider_info, sqlalchemy.Row):
|
|
provider_entity = persistence_model.ModelProvider(**provider_info._mapping)
|
|
elif isinstance(provider_info, dict):
|
|
provider_entity = persistence_model.ModelProvider(**provider_info)
|
|
else:
|
|
provider_entity = provider_info
|
|
|
|
# Get requester manifest to check for litellm_provider
|
|
requester_manifest = self.get_available_requester_manifest_by_name(provider_entity.requester)
|
|
|
|
# Build config from base_url
|
|
config = {'base_url': provider_entity.base_url}
|
|
|
|
# Check if requester manifest specifies litellm_provider
|
|
if requester_manifest and requester_manifest.spec.get('litellm_provider'):
|
|
# Use unified LiteLLMRequester with provider prefix
|
|
# Map litellm_provider (YAML spec) to custom_llm_provider (config)
|
|
config['custom_llm_provider'] = requester_manifest.spec['litellm_provider']
|
|
requester_inst = litellmchat.LiteLLMRequester(
|
|
ap=self.ap,
|
|
config=config,
|
|
)
|
|
self.ap.logger.debug(
|
|
f'Using LiteLLMRequester for {provider_entity.requester} '
|
|
f'with custom_llm_provider={config["custom_llm_provider"]}'
|
|
)
|
|
else:
|
|
# Use original requester class (for backward compatibility)
|
|
if provider_entity.requester not in self.requester_dict:
|
|
raise provider_errors.RequesterNotFoundError(provider_entity.requester)
|
|
requester_inst = self.requester_dict[provider_entity.requester](
|
|
ap=self.ap,
|
|
config=config,
|
|
)
|
|
|
|
await requester_inst.initialize()
|
|
|
|
token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or [])
|
|
|
|
provider = requester.RuntimeProvider(
|
|
provider_entity=provider_entity,
|
|
token_mgr=token_mgr,
|
|
requester=requester_inst,
|
|
)
|
|
return provider
|
|
|
|
async def remove_provider(self, provider_uuid: str):
|
|
"""Remove provider
|
|
|
|
This method will not consider the models using this provider,
|
|
because the models should be removed by the caller.
|
|
"""
|
|
del self.provider_dict[provider_uuid]
|
|
|
|
async def reload_provider(self, provider_uuid: str):
|
|
"""Reload provider"""
|
|
provider_entity = await self.ap.persistence_mgr.execute_async(
|
|
sqlalchemy.select(persistence_model.ModelProvider).where(
|
|
persistence_model.ModelProvider.uuid == provider_uuid
|
|
)
|
|
)
|
|
provider_entity = provider_entity.first()
|
|
if provider_entity is None:
|
|
raise provider_errors.ProviderNotFoundError(provider_uuid)
|
|
|
|
new_runtime_provider = await self.load_provider(provider_entity)
|
|
|
|
# update refs in runtime models
|
|
for model in self.llm_models:
|
|
if model.provider.provider_entity.uuid == provider_uuid:
|
|
model.provider = new_runtime_provider
|
|
for model in self.embedding_models:
|
|
if model.provider.provider_entity.uuid == provider_uuid:
|
|
model.provider = new_runtime_provider
|
|
for model in self.rerank_models:
|
|
if model.provider.provider_entity.uuid == provider_uuid:
|
|
model.provider = new_runtime_provider
|
|
|
|
# update ref in provider dict
|
|
self.provider_dict[provider_uuid] = new_runtime_provider
|
|
|
|
async def load_llm_model_with_provider(
|
|
self,
|
|
model_info: persistence_model.LLMModel | sqlalchemy.Row,
|
|
provider: requester.RuntimeProvider,
|
|
) -> requester.RuntimeLLMModel:
|
|
"""Load LLM model with provider info"""
|
|
if isinstance(model_info, sqlalchemy.Row):
|
|
model_info = persistence_model.LLMModel(**model_info._mapping)
|
|
|
|
runtime_llm_model = requester.RuntimeLLMModel(
|
|
model_entity=model_info,
|
|
provider=provider,
|
|
)
|
|
|
|
return runtime_llm_model
|
|
|
|
async def load_embedding_model_with_provider(
|
|
self,
|
|
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row,
|
|
provider: requester.RuntimeProvider,
|
|
) -> requester.RuntimeEmbeddingModel:
|
|
"""Load embedding model with provider info"""
|
|
if isinstance(model_info, sqlalchemy.Row):
|
|
model_info = persistence_model.EmbeddingModel(**model_info._mapping)
|
|
|
|
runtime_embedding_model = requester.RuntimeEmbeddingModel(
|
|
model_entity=model_info,
|
|
provider=provider,
|
|
)
|
|
|
|
return runtime_embedding_model
|
|
|
|
async def load_rerank_model_with_provider(
|
|
self,
|
|
model_info: persistence_model.RerankModel | sqlalchemy.Row,
|
|
provider: requester.RuntimeProvider,
|
|
) -> requester.RuntimeRerankModel:
|
|
"""Load rerank model with provider info"""
|
|
if isinstance(model_info, sqlalchemy.Row):
|
|
model_info = persistence_model.RerankModel(**model_info._mapping)
|
|
|
|
runtime_rerank_model = requester.RuntimeRerankModel(
|
|
model_entity=model_info,
|
|
provider=provider,
|
|
)
|
|
|
|
return runtime_rerank_model
|
|
|
|
async def load_llm_model(self, model_info: dict):
|
|
"""Load LLM model from dict (with provider info)"""
|
|
provider_info = model_info.get('provider', {})
|
|
if not provider_info:
|
|
raise ValueError('Provider info is required')
|
|
|
|
model_entity = persistence_model.LLMModel(
|
|
uuid=model_info.get('uuid', ''),
|
|
name=model_info.get('name', ''),
|
|
provider_uuid=model_info.get('provider_uuid', ''),
|
|
abilities=model_info.get('abilities', []),
|
|
extra_args=model_info.get('extra_args', {}),
|
|
)
|
|
|
|
provider_entity = persistence_model.ModelProvider(
|
|
uuid=provider_info.get('uuid', ''),
|
|
name=provider_info.get('name', ''),
|
|
requester=provider_info.get('requester', ''),
|
|
base_url=provider_info.get('base_url', ''),
|
|
api_keys=provider_info.get('api_keys', []),
|
|
)
|
|
|
|
await self.load_llm_model_with_provider(model_entity, provider_entity)
|
|
|
|
async def load_embedding_model(self, model_info: dict):
|
|
"""Load embedding model from dict (with provider info)"""
|
|
provider_info = model_info.get('provider', {})
|
|
if not provider_info:
|
|
raise ValueError('Provider info is required')
|
|
|
|
model_entity = persistence_model.EmbeddingModel(
|
|
uuid=model_info.get('uuid', ''),
|
|
name=model_info.get('name', ''),
|
|
provider_uuid=model_info.get('provider_uuid', ''),
|
|
extra_args=model_info.get('extra_args', {}),
|
|
)
|
|
|
|
provider_entity = persistence_model.ModelProvider(
|
|
uuid=provider_info.get('uuid', ''),
|
|
name=provider_info.get('name', ''),
|
|
requester=provider_info.get('requester', ''),
|
|
base_url=provider_info.get('base_url', ''),
|
|
api_keys=provider_info.get('api_keys', []),
|
|
)
|
|
|
|
await self.load_embedding_model_with_provider(model_entity, provider_entity)
|
|
|
|
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
|
|
"""Get LLM model by uuid"""
|
|
for model in self.llm_models:
|
|
if model.model_entity.uuid == uuid:
|
|
return model
|
|
raise ValueError(f'LLM model {uuid} not found')
|
|
|
|
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
|
|
"""Get embedding model by uuid"""
|
|
for model in self.embedding_models:
|
|
if model.model_entity.uuid == uuid:
|
|
return model
|
|
raise ValueError(f'Embedding model {uuid} not found')
|
|
|
|
async def get_rerank_model_by_uuid(self, uuid: str) -> requester.RuntimeRerankModel:
|
|
"""Get rerank model by uuid"""
|
|
for model in self.rerank_models:
|
|
if model.model_entity.uuid == uuid:
|
|
return model
|
|
raise ValueError(f'Rerank model {uuid} not found')
|
|
|
|
async def remove_llm_model(self, model_uuid: str):
|
|
"""Remove LLM model"""
|
|
for model in self.llm_models:
|
|
if model.model_entity.uuid == model_uuid:
|
|
self.llm_models.remove(model)
|
|
return
|
|
|
|
async def remove_embedding_model(self, model_uuid: str):
|
|
"""Remove embedding model"""
|
|
for model in self.embedding_models:
|
|
if model.model_entity.uuid == model_uuid:
|
|
self.embedding_models.remove(model)
|
|
return
|
|
|
|
async def remove_rerank_model(self, model_uuid: str):
|
|
"""Remove rerank model"""
|
|
for model in self.rerank_models:
|
|
if model.model_entity.uuid == model_uuid:
|
|
self.rerank_models.remove(model)
|
|
return
|
|
|
|
def get_available_requesters_info(self, model_type: str) -> list[dict]:
|
|
"""Get all available requesters"""
|
|
if model_type != '':
|
|
return [
|
|
component.to_plain_dict()
|
|
for component in self.requester_components
|
|
if model_type in component.spec['support_type']
|
|
]
|
|
else:
|
|
return [component.to_plain_dict() for component in self.requester_components]
|
|
|
|
def get_available_requester_info_by_name(self, name: str) -> dict | None:
|
|
"""Get requester info by name"""
|
|
for component in self.requester_components:
|
|
if component.metadata.name == name:
|
|
return component.to_plain_dict()
|
|
return None
|
|
|
|
def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None:
|
|
"""Get requester manifest by name"""
|
|
for component in self.requester_components:
|
|
if component.metadata.name == name:
|
|
return component
|
|
return None
|