mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-14 09:46:03 +00:00
* refactor(provider): use LiteLLM as unified LLM requester backend
- Replace 23+ individual requester implementations with unified litellmchat.py
- Add litellm_provider field to 27 YAML manifests for provider routing
- Delete redundant requester subclasses
- Add unit tests for LiteLLMRequester (29 tests)
- Fix num_retries parameter name (was max_retries)
- Fix exception handling order for subclass exceptions
LiteLLM provides a unified API for 100+ providers, eliminating the need for
provider-specific requesters.
* fix: ruff format provider.py
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* refactor(provider): simplify LiteLLM requester usage handling
- Remove unused Anthropic-specific tool schema generation
- Share completion argument construction between normal and streaming calls
- Use LiteLLM/OpenAI native usage fields for monitoring
- Collect stream token usage from LiteLLM stream_options
- Update LiteLLM requester tests for unified usage fields
* restore: restore deleted provider requester files
Restore individual provider requester implementations that were
removed in de61b5d3. These files coexist with the unified
litellmchat.py backend.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
* feat: update requesters and improve provider selection UI
- Added `litellm_provider` field to various requesters' YAML configurations.
- Removed obsolete Python requester files for OpenRouter, PPIO, QHAIGC, ShengSuanYun, SiliconFlow, Space, TokenPony, VolcArk, and Xai.
- Introduced new requesters for Tencent and Together AI with corresponding YAML configurations and SVG icons.
- Enhanced the ProviderForm component to include a searchable dropdown for selecting providers, improving user experience.
- Updated localization files to include search provider text for both English and Chinese.
* fix(provider): align litellm rebase with master
* fix(provider): capture streaming token usage; add token observability
The LiteLLM streaming requester only captured usage when a chunk had an
empty `choices` list. Many OpenAI-compatible gateways (e.g. new-api) and
providers send the final usage payload in a chunk that still carries an
empty-delta choice, so streamed calls always recorded 0 tokens in the
monitoring logs/dashboard (non-streaming worked).
- Capture stream usage whenever a chunk carries it, regardless of choices
- Add robust _normalize_usage (dict/obj shapes, derive missing total_tokens)
- Register litellm in bootutils/deps.py (was in pyproject only)
- Add MonitoringService.get_token_statistics + /monitoring/token-statistics
endpoint: summary, per-model breakdown, token timeseries, and a
zero-token-success data-quality signal
- Add TokenMonitoring dashboard tab (summary tiles, stacked token chart,
per-model table) + i18n (en/zh)
- Regression tests for stream usage capture and usage normalization
Verified end-to-end against a real OpenAI-compatible endpoint with
gpt-5.5 and claude-opus-4-8: tokens now recorded non-zero for both
streaming and non-streaming paths.
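The usage handling described in this commit can be sketched roughly as follows. This is a minimal illustration, not the project's `_normalize_usage`: the names `normalize_usage` and `capture_stream_usage` are hypothetical, and the `prompt_tokens` / `completion_tokens` / `total_tokens` keys assume OpenAI-style usage fields.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any


def normalize_usage(usage: Any) -> dict[str, int]:
    """Coerce a dict- or object-shaped usage payload into plain integer counts."""

    def read(name: str) -> int:
        if isinstance(usage, dict):
            value = usage.get(name)
        else:
            value = getattr(usage, name, None)
        return int(value or 0)

    prompt = read('prompt_tokens')
    completion = read('completion_tokens')
    # Derive the total when the gateway omits it or reports zero.
    total = read('total_tokens') or prompt + completion
    return {'prompt_tokens': prompt, 'completion_tokens': completion, 'total_tokens': total}


def capture_stream_usage(chunks: list[Any]) -> dict[str, int] | None:
    """Keep the last usage payload seen in a stream, whether or not the chunk has choices."""
    captured = None
    for chunk in chunks:
        usage = chunk.get('usage') if isinstance(chunk, dict) else getattr(chunk, 'usage', None)
        if usage:
            captured = normalize_usage(usage)
    return captured


# A final chunk that still carries an empty-delta choice alongside its usage payload.
final_chunk = SimpleNamespace(
    choices=[SimpleNamespace(delta=SimpleNamespace(content=None))],
    usage=SimpleNamespace(prompt_tokens=12, completion_tokens=30),
)
print(capture_stream_usage([final_chunk]))
```

The point of the sketch is that usage capture depends only on the presence of a usage payload, never on whether `choices` is empty.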
* refactor(provider): simplify litellm capabilities
* style: simplify wrapped expressions
* feat(models): persist context metadata
* fix(provider): handle dict embeddings and openai-compatible rerank in LiteLLMRequester
- invoke_embedding: support both object- and dict-shaped response.data
entries (OpenAI-compatible gateways like new-api return dicts)
- invoke_rerank: litellm.arerank rejects the 'openai' provider, so for
openai-compatible (or unspecified) providers call the standard
Jina/Cohere-style POST /v1/rerank endpoint directly over HTTP
- accept both 'relevance_score' and 'score' fields in rerank results
- add unit tests for the openai-compatible HTTP rerank path
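A rough sketch of the response-shape tolerance described in this commit, assuming Jina/Cohere-style rerank results that carry an `index` plus either `relevance_score` or `score`. The helper names `extract_embeddings` and `extract_rerank_scores` are hypothetical and are not taken from `LiteLLMRequester`.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any


def _field(item: Any, name: str) -> Any:
    """Read a field from either a dict-shaped or an object-shaped entry."""
    if isinstance(item, dict):
        return item.get(name)
    return getattr(item, name, None)


def extract_embeddings(data: list[Any]) -> list[list[float]]:
    """Collect embedding vectors from response.data entries of either shape."""
    return [list(_field(entry, 'embedding')) for entry in data]


def extract_rerank_scores(results: list[dict]) -> list[tuple[int, float]]:
    """Return (document index, score) pairs, accepting both score field names."""
    pairs = []
    for result in results:
        score = result.get('relevance_score')
        if score is None:
            score = result.get('score')
        # Assumption for this sketch: a result with neither field scores 0.0.
        pairs.append((result['index'], float(score if score is not None else 0.0)))
    return pairs


print(extract_embeddings([{'embedding': [0.1, 0.2]}, SimpleNamespace(embedding=[0.3])]))
print(extract_rerank_scores([{'index': 0, 'relevance_score': 0.9}, {'index': 1, 'score': 0.2}]))
```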
* feat(provider): enforce requester support_type when adding models
- frontend: AddModelPopover only shows model-type tabs (llm/embedding/
rerank) that the provider's requester declares in its manifest
support_type; ModelsDialog fetches requester manifests and maps
requester -> support_type, passed down through ProviderCard
- backend: add _validate_provider_supports guard in create_llm_model /
create_embedding_model / create_rerank_model so a model cannot be
attached to a provider whose requester does not support that type,
even if the frontend restriction is bypassed (manifests without
support_type are allowed for backward compatibility)
- manifests: correct support_type for providers that do not offer all
three model types:
- llm only: anthropic, deepseek, groq, moonshot, openrouter, xai
- llm + text-embedding: openai, gemini, mistral
- add rerank to new-api (verified working via /v1/rerank)
- set llm + text-embedding + rerank for aggregator/unknown gateways
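The tab restriction is frontend logic, but the rule itself is small enough to sketch in Python. The manifest shape used here (a `name` key next to `spec.support_type`) and the function name `visible_model_types` are assumptions for illustration only.

```python
ALL_MODEL_TYPES = ('llm', 'text-embedding', 'rerank')

# Hypothetical manifest entries mirroring the support_type values listed above.
MANIFESTS = [
    {'name': 'anthropic', 'spec': {'support_type': ['llm']}},
    {'name': 'openai', 'spec': {'support_type': ['llm', 'text-embedding']}},
    {'name': 'legacy', 'spec': {}},  # no support_type declared
]


def visible_model_types(manifests: list, requester: str) -> list:
    """Return the model-type tabs to show for a requester.

    Manifests without support_type (or unknown requesters) fall back to all
    types, matching the backward-compatible behaviour described above.
    """
    for manifest in manifests:
        if manifest.get('name') == requester:
            declared = manifest.get('spec', {}).get('support_type')
            if declared:
                return [t for t in ALL_MODEL_TYPES if t in declared]
    return list(ALL_MODEL_TYPES)


for name in ('anthropic', 'openai', 'legacy'):
    print(name, visible_model_types(MANIFESTS, name))
```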
* feat(provider): add searchable alias to requester manifests
- add a free-text 'alias' field to every requester manifest spec,
containing the vendor's English/Chinese names, pinyin, common
nicknames and flagship model-series names (e.g. moonshot -> kimi,
月之暗面; zhipu -> glm, 智谱清言)
- frontend: ProviderForm requester search now also matches against
alias (substring/contains), so searching 'kimi' surfaces Moonshot,
'硅基' surfaces SiliconFlow, etc.
- also fix support_type: openrouter (relay) supports embedding+rerank;
LangBot Space gains rerank (coming soon)
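The alias matching lives in the frontend `ProviderForm`; the Python sketch below only illustrates the substring rule. The requester entries and alias strings are made-up examples, and `matches_requester` / `search_requesters` are hypothetical names.

```python
# Illustrative entries only; the real alias text lives in the requester manifests.
REQUESTERS = [
    {'label': 'Moonshot', 'alias': 'moonshot kimi 月之暗面'},
    {'label': 'SiliconFlow', 'alias': 'siliconflow 硅基流动'},
]


def matches_requester(query: str, label: str, alias: str) -> bool:
    """Case-insensitive substring match against the label or the free-text alias."""
    needle = query.strip().lower()
    if not needle:
        return True  # an empty search shows every requester
    return needle in label.lower() or needle in alias.lower()


def search_requesters(query: str) -> list:
    """Return the labels of all requesters matching the query."""
    return [r['label'] for r in REQUESTERS if matches_requester(query, r['label'], r['alias'])]


print(search_requesters('kimi'))
print(search_requesters('硅基'))
```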
* fix(provider): make support_type guard defensive against incomplete model_mgr
- _validate_provider_supports now uses getattr to gracefully skip when
model_mgr / provider_dict / manifest lookup is unavailable, instead of
raising AttributeError (fixes unit tests that mock ap.model_mgr as a
bare SimpleNamespace)
- add TestValidateProviderSupports covering: allow supported type,
reject unsupported type, allow when support_type missing, allow when
provider unknown, degrade safely when model_mgr is incomplete
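The five cases can be exercised against a simplified, synchronous stand-in for the guard. This is a sketch under stated assumptions: `validate_supports` and `make_app` are hypothetical names, and the stand-in condenses the lookups rather than reproducing the real `_validate_provider_supports`.

```python
from types import SimpleNamespace


def validate_supports(ap, provider_uuid: str, model_type: str) -> None:
    """Simplified stand-in: skip silently whenever a lookup is unavailable."""
    model_mgr = getattr(ap, 'model_mgr', None)
    provider_dict = getattr(model_mgr, 'provider_dict', None) or {}
    provider = provider_dict.get(provider_uuid)
    requester = getattr(getattr(provider, 'provider_entity', None), 'requester', None)
    get_manifest = getattr(model_mgr, 'get_available_requester_manifest_by_name', None)
    if not requester or not callable(get_manifest):
        return
    spec = getattr(get_manifest(requester), 'spec', None) or {}
    support_type = spec.get('support_type') if isinstance(spec, dict) else None
    if support_type and model_type not in support_type:
        raise ValueError(f'Provider requester "{requester}" does not support {model_type} models')


def make_app(support_type):
    """Build a mock application whose single provider 'p1' uses requester 'demo'."""
    spec = {} if support_type is None else {'support_type': support_type}
    provider = SimpleNamespace(provider_entity=SimpleNamespace(requester='demo'))
    return SimpleNamespace(
        model_mgr=SimpleNamespace(
            provider_dict={'p1': provider},
            get_available_requester_manifest_by_name=lambda name: SimpleNamespace(spec=spec),
        )
    )


validate_supports(make_app(['llm']), 'p1', 'llm')  # supported type: allowed
validate_supports(make_app(None), 'p1', 'rerank')  # support_type missing: allowed
validate_supports(make_app(['llm']), 'other', 'rerank')  # unknown provider: allowed
validate_supports(SimpleNamespace(model_mgr=SimpleNamespace()), 'p1', 'rerank')  # bare mock: skipped
try:
    validate_supports(make_app(['llm']), 'p1', 'rerank')  # unsupported type: rejected
except ValueError as exc:
    print(exc)
```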
* fix(persistence): guard 0004 migration against missing llm_models table
The 0004_add_llm_model_context_length migration called
inspector.get_columns('llm_models') unconditionally, raising
NoSuchTableError when the table does not exist (e.g. migrating a
fresh/empty DB, as exercised by the integration tests where
create_all() registers no tables because the ORM models are not
imported). Every other migration guards with a table-existence check
first; add the same guard here for both upgrade and downgrade.
Also restore the test head assertion to 0004 (it had been lowered to
0003 to mask this failure).
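The guard pattern can be illustrated with the standard-library `sqlite3` module. This is only an analogy: the real migration uses Alembic with a SQLAlchemy inspector, and the `context_length` column name and `INTEGER` type are assumptions inferred from the migration's name.

```python
import sqlite3


def table_exists(conn: sqlite3.Connection, table: str) -> bool:
    """Check sqlite_master for the table instead of assuming it is present."""
    row = conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    return row is not None


def column_names(conn: sqlite3.Connection, table: str) -> set:
    """PRAGMA table_info yields one row per column; index 1 holds the column name."""
    return {row[1] for row in conn.execute(f'PRAGMA table_info({table})')}


def upgrade(conn: sqlite3.Connection) -> None:
    """Add the column only when the table exists and does not already have it."""
    if not table_exists(conn, 'llm_models'):
        return  # fresh or empty database: nothing to alter
    if 'context_length' not in column_names(conn, 'llm_models'):
        conn.execute('ALTER TABLE llm_models ADD COLUMN context_length INTEGER')


conn = sqlite3.connect(':memory:')
upgrade(conn)  # no llm_models table yet: returns without raising
conn.execute('CREATE TABLE llm_models (uuid TEXT)')
upgrade(conn)  # adds the column
upgrade(conn)  # second run is a no-op
print(sorted(column_names(conn, 'llm_models')))
```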
* Merge branch 'master' into feat/litellm
Resolve conflicts:
- uv.lock: regenerated via 'uv lock' to reconcile litellm/fastuuid
(ours) with openai bump (master).
- Alembic migrations: master added 0004_add_mcp_readme while this
branch added 0004_add_llm_model_context_length, both as children of
0003 (would create multiple heads). Re-chain the litellm migration as
0005_add_llm_model_context_length with down_revision=0004_add_mcp_readme
for a single linear head. Update test head assertion accordingly.
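Why the re-chaining matters can be shown with a small head-finding sketch. It models only the revisions named above (earlier history is left out), and `find_heads` is a hypothetical helper, not an Alembic API.

```python
from __future__ import annotations


def find_heads(down_revisions: dict[str, str | None]) -> set[str]:
    """A head is a revision that no other revision names as its down_revision."""
    parents = {parent for parent in down_revisions.values() if parent is not None}
    return set(down_revisions) - parents


# Before the merge fix: both 0004 revisions are children of 0003.
before = {
    '0004_add_mcp_readme': '0003',
    '0004_add_llm_model_context_length': '0003',
}

# After re-chaining: the litellm migration follows master's 0004.
after = {
    '0004_add_mcp_readme': '0003',
    '0005_add_llm_model_context_length': '0004_add_mcp_readme',
}

print(sorted(find_heads(before)))
print(sorted(find_heads(after)))
```

With two heads, a plain upgrade to `head` is ambiguous, which is why a single linear chain is preferable here.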
* fix(persistence): shorten migration revision id to fit varchar(32)
PostgreSQL stores alembic_version.version_num as varchar(32).
'0005_add_llm_model_context_length' (33 chars) overflowed it, raising
StringDataRightTruncationError in the PG migration tests. Rename the
revision (and file) to '0005_add_llm_context_length' (27 chars) and
update the head assertions in both SQLite and PostgreSQL migration
tests.
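A check like the following would catch an over-long revision id before it reaches PostgreSQL. It is a sketch only; `fits_version_column` and the constant name are hypothetical, while the 32-character limit and both ids come from the commit message above.

```python
ALEMBIC_VERSION_MAX_LEN = 32  # alembic_version.version_num is varchar(32) on PostgreSQL


def fits_version_column(revision_id: str, limit: int = ALEMBIC_VERSION_MAX_LEN) -> bool:
    """Return True when the revision id fits in the version_num column."""
    return len(revision_id) <= limit


for revision_id in ('0005_add_llm_model_context_length', '0005_add_llm_context_length'):
    print(revision_id, len(revision_id), fits_version_column(revision_id))
```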
---------
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: fdc310 <2213070223@qq.com>
Co-authored-by: RockChinQ <rockchinq@gmail.com>
586 lines
24 KiB
Python
from __future__ import annotations

import uuid

import sqlalchemy
from langbot_plugin.api.entities.builtin.provider import message as provider_message

from ....core import app
from ....entity.persistence import model as persistence_model
from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester


def _parse_provider_api_keys(provider_dict: dict) -> dict:
    """Parse api_keys if it's a JSON string"""
    if isinstance(provider_dict.get('api_keys'), str):
        import json

        try:
            provider_dict['api_keys'] = json.loads(provider_dict['api_keys'])
        except Exception:
            provider_dict['api_keys'] = []
    return provider_dict


def _runtime_model_data(model_uuid: str, model_data: dict) -> dict:
    """Return model data for rebuilding runtime models after an update.

    Update payloads intentionally omit uuid before writing to the database.
    Runtime model entities still need the stable uuid so pipeline configs can
    resolve the in-memory model immediately after an edit, without requiring a
    process restart.
    """
    return {**model_data, 'uuid': model_uuid}


async def _validate_provider_supports(ap: app.Application, provider_uuid: str, model_type: str) -> None:
    """Validate that the provider's requester declares support for ``model_type``.

    ``model_type`` is one of the manifest ``support_type`` values:
    'llm', 'text-embedding', 'rerank'. Raises ValueError when the requester
    manifest does not list the requested type. This is a server-side guard so
    a model cannot be attached to a provider that does not support it, even if
    the frontend tab restriction is bypassed.
    """
    model_mgr = getattr(ap, 'model_mgr', None)
    if model_mgr is None:
        return

    provider_dict = getattr(model_mgr, 'provider_dict', None)
    if not provider_dict:
        return
    runtime_provider = provider_dict.get(provider_uuid)
    if runtime_provider is None:
        return

    requester_name = getattr(getattr(runtime_provider, 'provider_entity', None), 'requester', None)
    if not requester_name:
        return

    get_manifest = getattr(model_mgr, 'get_available_requester_manifest_by_name', None)
    if not callable(get_manifest):
        return
    manifest = get_manifest(requester_name)
    if manifest is None:
        return

    spec = getattr(manifest, 'spec', None) or {}
    support_type = spec.get('support_type') if isinstance(spec, dict) else None
    # When a manifest omits support_type, do not block (backward compatible).
    if not support_type:
        return
    if model_type not in support_type:
        raise ValueError(f'Provider requester "{requester_name}" does not support {model_type} models')


class LLMModelsService:
    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
        """Get all LLM models with provider info"""
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
        models = result.all()

        # Get all providers for lookup
        providers_result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.ModelProvider)
        )
        providers = {p.uuid: p for p in providers_result.all()}

        models_list = []
        for model in models:
            model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
            provider = providers.get(model.provider_uuid)
            if provider:
                provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
                provider_dict = _parse_provider_api_keys(provider_dict)
                if not include_secret:
                    provider_dict['api_keys'] = ['***'] * len(provider_dict.get('api_keys', []))
                model_dict['provider'] = provider_dict
            models_list.append(model_dict)

        return models_list

    async def get_llm_models_by_provider(self, provider_uuid: str) -> list[dict]:
        """Get LLM models by provider UUID"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.LLMModel).where(
                persistence_model.LLMModel.provider_uuid == provider_uuid
            )
        )
        models = result.all()
        return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, m) for m in models]

    async def create_llm_model(
        self, model_data: dict, preserve_uuid: bool = False, auto_set_to_default_pipeline: bool = True
    ) -> str:
        """Create a new LLM model"""
        if not preserve_uuid:
            model_data['uuid'] = str(uuid.uuid4())

        # Handle provider creation if needed
        if 'provider' in model_data:
            provider_data = model_data.pop('provider')
            if provider_data.get('uuid'):
                model_data['provider_uuid'] = provider_data['uuid']
            else:
                # Create new provider
                provider_uuid = await self.ap.provider_service.find_or_create_provider(
                    requester=provider_data.get('requester', ''),
                    base_url=provider_data.get('base_url', ''),
                    api_keys=provider_data.get('api_keys', []),
                )
                model_data['provider_uuid'] = provider_uuid

        await _validate_provider_supports(self.ap, model_data['provider_uuid'], 'llm')

        await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data))

        runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
        if runtime_provider is None:
            raise Exception('provider not found')

        runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider(
            persistence_model.LLMModel(**model_data),
            runtime_provider,
        )
        self.ap.model_mgr.llm_models.append(runtime_llm_model)

        if auto_set_to_default_pipeline:
            # set the default pipeline model to this model
            result = await self.ap.persistence_mgr.execute_async(
                sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
                    persistence_pipeline.LegacyPipeline.is_default == True
                )
            )
            pipeline = result.first()
            if pipeline is not None:
                model_config = pipeline.config.get('ai', {}).get('local-agent', {}).get('model', {})
                if not model_config.get('primary', ''):
                    pipeline_config = pipeline.config
                    pipeline_config['ai']['local-agent']['model'] = {
                        'primary': model_data['uuid'],
                        'fallbacks': [],
                    }
                    pipeline_data = {'config': pipeline_config}
                    await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)

        return model_data['uuid']

    async def get_llm_model(self, model_uuid: str) -> dict | None:
        """Get a single LLM model with provider info"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
        )
        model = result.first()
        if model is None:
            return None

        model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)

        # Get provider
        provider_result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.ModelProvider).where(
                persistence_model.ModelProvider.uuid == model.provider_uuid
            )
        )
        provider = provider_result.first()
        if provider:
            provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
            model_dict['provider'] = _parse_provider_api_keys(provider_dict)

        return model_dict

    async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
        """Update an existing LLM model"""
        if 'uuid' in model_data:
            del model_data['uuid']

        # Handle provider update if needed
        if 'provider' in model_data:
            provider_data = model_data.pop('provider')
            if provider_data.get('uuid'):
                model_data['provider_uuid'] = provider_data['uuid']
            else:
                provider_uuid = await self.ap.provider_service.find_or_create_provider(
                    requester=provider_data.get('requester', ''),
                    base_url=provider_data.get('base_url', ''),
                    api_keys=provider_data.get('api_keys', []),
                )
                model_data['provider_uuid'] = provider_uuid

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_model.LLMModel)
            .where(persistence_model.LLMModel.uuid == model_uuid)
            .values(**model_data)
        )

        await self.ap.model_mgr.remove_llm_model(model_uuid)

        runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
        if runtime_provider is None:
            raise Exception('provider not found')

        runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider(
            persistence_model.LLMModel(**_runtime_model_data(model_uuid, model_data)),
            runtime_provider,
        )
        self.ap.model_mgr.llm_models.append(runtime_llm_model)

    async def delete_llm_model(self, model_uuid: str) -> None:
        """Delete an LLM model"""
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
        )
        await self.ap.model_mgr.remove_llm_model(model_uuid)

    async def test_llm_model(self, model_uuid: str, model_data: dict) -> None:
        """Test an LLM model"""
        runtime_llm_model: model_requester.RuntimeLLMModel | None = None

        if model_uuid != '_':
            for model in self.ap.model_mgr.llm_models:
                if model.model_entity.uuid == model_uuid:
                    runtime_llm_model = model
                    break
            if runtime_llm_model is None:
                raise Exception('model not found')
        else:
            runtime_llm_model = await self.ap.model_mgr.init_temporary_runtime_llm_model(model_data)

        extra_args = model_data.get('extra_args', {})
        await runtime_llm_model.provider.invoke_llm(
            query=None,
            model=runtime_llm_model,
            messages=[provider_message.Message(role='user', content='Hello, world! Please just reply a "Hello".')],
            funcs=[],
            extra_args=extra_args,
        )


class EmbeddingModelsService:
    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    async def get_embedding_models(self) -> list[dict]:
        """Get all embedding models with provider info"""
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
        models = result.all()

        providers_result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.ModelProvider)
        )
        providers = {p.uuid: p for p in providers_result.all()}

        models_list = []
        for model in models:
            model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)
            provider = providers.get(model.provider_uuid)
            if provider:
                provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
                model_dict['provider'] = _parse_provider_api_keys(provider_dict)
            models_list.append(model_dict)

        return models_list

    async def get_embedding_models_by_provider(self, provider_uuid: str) -> list[dict]:
        """Get embedding models by provider UUID"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.EmbeddingModel).where(
                persistence_model.EmbeddingModel.provider_uuid == provider_uuid
            )
        )
        models = result.all()
        return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, m) for m in models]

    async def create_embedding_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
        """Create a new embedding model"""
        if not preserve_uuid:
            model_data['uuid'] = str(uuid.uuid4())

        if 'provider' in model_data:
            provider_data = model_data.pop('provider')
            if provider_data.get('uuid'):
                model_data['provider_uuid'] = provider_data['uuid']
            else:
                provider_uuid = await self.ap.provider_service.find_or_create_provider(
                    requester=provider_data.get('requester', ''),
                    base_url=provider_data.get('base_url', ''),
                    api_keys=provider_data.get('api_keys', []),
                )
                model_data['provider_uuid'] = provider_uuid

        await _validate_provider_supports(self.ap, model_data['provider_uuid'], 'text-embedding')

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data)
        )

        runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
        if runtime_provider is None:
            raise Exception('provider not found')

        runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider(
            persistence_model.EmbeddingModel(**model_data),
            runtime_provider,
        )
        self.ap.model_mgr.embedding_models.append(runtime_embedding_model)

        return model_data['uuid']

    async def get_embedding_model(self, model_uuid: str) -> dict | None:
        """Get a single embedding model with provider info"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.EmbeddingModel).where(
                persistence_model.EmbeddingModel.uuid == model_uuid
            )
        )
        model = result.first()
        if model is None:
            return None

        model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model)

        provider_result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.ModelProvider).where(
                persistence_model.ModelProvider.uuid == model.provider_uuid
            )
        )
        provider = provider_result.first()
        if provider:
            provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
            model_dict['provider'] = _parse_provider_api_keys(provider_dict)

        return model_dict

    async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None:
        """Update an existing embedding model"""
        if 'uuid' in model_data:
            del model_data['uuid']

        if 'provider' in model_data:
            provider_data = model_data.pop('provider')
            if provider_data.get('uuid'):
                model_data['provider_uuid'] = provider_data['uuid']
            else:
                provider_uuid = await self.ap.provider_service.find_or_create_provider(
                    requester=provider_data.get('requester', ''),
                    base_url=provider_data.get('base_url', ''),
                    api_keys=provider_data.get('api_keys', []),
                )
                model_data['provider_uuid'] = provider_uuid

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_model.EmbeddingModel)
            .where(persistence_model.EmbeddingModel.uuid == model_uuid)
            .values(**model_data)
        )

        await self.ap.model_mgr.remove_embedding_model(model_uuid)

        runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
        if runtime_provider is None:
            raise Exception('provider not found')

        runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider(
            persistence_model.EmbeddingModel(**_runtime_model_data(model_uuid, model_data)),
            runtime_provider,
        )
        self.ap.model_mgr.embedding_models.append(runtime_embedding_model)

    async def delete_embedding_model(self, model_uuid: str) -> None:
        """Delete an embedding model"""
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_model.EmbeddingModel).where(
                persistence_model.EmbeddingModel.uuid == model_uuid
            )
        )
        await self.ap.model_mgr.remove_embedding_model(model_uuid)

    async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None:
        """Test an embedding model"""
        runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None

        if model_uuid != '_':
            for model in self.ap.model_mgr.embedding_models:
                if model.model_entity.uuid == model_uuid:
                    runtime_embedding_model = model
                    break
            if runtime_embedding_model is None:
                raise Exception('model not found')
        else:
            runtime_embedding_model = await self.ap.model_mgr.init_temporary_runtime_embedding_model(model_data)

        await runtime_embedding_model.provider.invoke_embedding(
            model=runtime_embedding_model,
            input_text=['Hello, world!'],
            extra_args={},
        )


class RerankModelsService:
    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    async def get_rerank_models(self) -> list[dict]:
        """Get all rerank models with provider info"""
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
        models = result.all()

        providers_result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.ModelProvider)
        )
        providers = {p.uuid: p for p in providers_result.all()}

        models_list = []
        for model in models:
            model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
            provider = providers.get(model.provider_uuid)
            if provider:
                provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
                model_dict['provider'] = _parse_provider_api_keys(provider_dict)
            models_list.append(model_dict)

        return models_list

    async def get_rerank_models_by_provider(self, provider_uuid: str) -> list[dict]:
        """Get rerank models by provider UUID"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.RerankModel).where(
                persistence_model.RerankModel.provider_uuid == provider_uuid
            )
        )
        models = result.all()
        return [self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, m) for m in models]

    async def create_rerank_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
        """Create a new rerank model"""
        if not preserve_uuid:
            model_data['uuid'] = str(uuid.uuid4())

        if 'provider' in model_data:
            provider_data = model_data.pop('provider')
            if provider_data.get('uuid'):
                model_data['provider_uuid'] = provider_data['uuid']
            else:
                provider_uuid = await self.ap.provider_service.find_or_create_provider(
                    requester=provider_data.get('requester', ''),
                    base_url=provider_data.get('base_url', ''),
                    api_keys=provider_data.get('api_keys', []),
                )
                model_data['provider_uuid'] = provider_uuid

        await _validate_provider_supports(self.ap, model_data['provider_uuid'], 'rerank')

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.insert(persistence_model.RerankModel).values(**model_data)
        )

        runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
        if runtime_provider is None:
            raise Exception('provider not found')

        runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
            persistence_model.RerankModel(**model_data),
            runtime_provider,
        )
        self.ap.model_mgr.rerank_models.append(runtime_rerank_model)

        return model_data['uuid']

    async def get_rerank_model(self, model_uuid: str) -> dict | None:
        """Get a single rerank model with provider info"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
        )
        model = result.first()
        if model is None:
            return None

        model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)

        provider_result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_model.ModelProvider).where(
                persistence_model.ModelProvider.uuid == model.provider_uuid
            )
        )
        provider = provider_result.first()
        if provider:
            provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
            model_dict['provider'] = _parse_provider_api_keys(provider_dict)

        return model_dict

    async def update_rerank_model(self, model_uuid: str, model_data: dict) -> None:
        """Update an existing rerank model"""
        if 'uuid' in model_data:
            del model_data['uuid']

        if 'provider' in model_data:
            provider_data = model_data.pop('provider')
            if provider_data.get('uuid'):
                model_data['provider_uuid'] = provider_data['uuid']
            else:
                provider_uuid = await self.ap.provider_service.find_or_create_provider(
                    requester=provider_data.get('requester', ''),
                    base_url=provider_data.get('base_url', ''),
                    api_keys=provider_data.get('api_keys', []),
                )
                model_data['provider_uuid'] = provider_uuid

        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_model.RerankModel)
            .where(persistence_model.RerankModel.uuid == model_uuid)
            .values(**model_data)
        )

        await self.ap.model_mgr.remove_rerank_model(model_uuid)

        runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
        if runtime_provider is None:
            raise Exception('provider not found')

        runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
            persistence_model.RerankModel(**_runtime_model_data(model_uuid, model_data)),
            runtime_provider,
        )
        self.ap.model_mgr.rerank_models.append(runtime_rerank_model)

    async def delete_rerank_model(self, model_uuid: str) -> None:
        """Delete a rerank model"""
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
        )
        await self.ap.model_mgr.remove_rerank_model(model_uuid)

    async def test_rerank_model(self, model_uuid: str, model_data: dict) -> None:
        """Test a rerank model"""
        runtime_rerank_model: model_requester.RuntimeRerankModel | None = None

        if model_uuid != '_':
            for model in self.ap.model_mgr.rerank_models:
                if model.model_entity.uuid == model_uuid:
                    runtime_rerank_model = model
                    break
            if runtime_rerank_model is None:
                raise Exception('model not found')
        else:
            runtime_rerank_model = await self.ap.model_mgr.init_temporary_runtime_rerank_model(model_data)

        await runtime_rerank_model.provider.invoke_rerank(
            model=runtime_rerank_model,
            query='What is artificial intelligence?',
            documents=[
                'Artificial intelligence is a branch of computer science.',
                'The weather is nice today.',
            ],
        )