Feat/rerank model (#2137)
* feat(provider): add rerank model management as a core model type * feat(provider): add rerank support to existing requesters and new rerank providers * feat(web): add rerank model management UI and pipeline config * fix(provider): correct rerank support_type after verification - Add rerank to OpenRouter (confirmed /api/v1/rerank endpoint) - Remove rerank from Ollama (no native support, PR #7219 unmerged) - Remove rerank from JiekouAI (no rerank docs found, URL path mismatch) * fix(provider): remove alru_cache from model getters and add rerank param hints * fix: resolve lint errors - Remove unused alru_cache import from modelmgr.py - Remove unused error_message variable in invoke_rerank - Fix prettier formatting in frontend files Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix: remove unused exception variable - Change `except Exception as e:` to `except Exception:` since e is not used - Fix prettier formatting in ProviderCard.tsx Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix: apply ruff format Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * feat(template): add rerank config fields to default pipeline config Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * chore: remove PR.md Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix(ui): remove duplicate rerank model form in AddModelPopover The form was being rendered twice: once in TabsContent manual mode and again in a separate conditional block for rerank tab. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
@@ -97,3 +97,51 @@ class EmbeddingModelsRouterGroup(group.RouterGroup):
|
|||||||
await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data)
|
await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data)
|
||||||
|
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
|
@group.group_class('models/rerank', '/api/v1/provider/models/rerank')
|
||||||
|
class RerankModelsRouterGroup(group.RouterGroup):
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
@self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||||
|
async def _() -> str:
|
||||||
|
if quart.request.method == 'GET':
|
||||||
|
provider_uuid = quart.request.args.get('provider_uuid')
|
||||||
|
if provider_uuid:
|
||||||
|
return self.success(
|
||||||
|
data={
|
||||||
|
'models': await self.ap.rerank_models_service.get_rerank_models_by_provider(provider_uuid)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return self.success(data={'models': await self.ap.rerank_models_service.get_rerank_models()})
|
||||||
|
elif quart.request.method == 'POST':
|
||||||
|
json_data = await quart.request.json
|
||||||
|
model_uuid = await self.ap.rerank_models_service.create_rerank_model(json_data)
|
||||||
|
return self.success(data={'uuid': model_uuid})
|
||||||
|
|
||||||
|
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||||
|
async def _(model_uuid: str) -> str:
|
||||||
|
if quart.request.method == 'GET':
|
||||||
|
model = await self.ap.rerank_models_service.get_rerank_model(model_uuid)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
return self.http_status(404, -1, 'model not found')
|
||||||
|
|
||||||
|
return self.success(data={'model': model})
|
||||||
|
elif quart.request.method == 'PUT':
|
||||||
|
json_data = await quart.request.json
|
||||||
|
|
||||||
|
await self.ap.rerank_models_service.update_rerank_model(model_uuid, json_data)
|
||||||
|
|
||||||
|
return self.success()
|
||||||
|
elif quart.request.method == 'DELETE':
|
||||||
|
await self.ap.rerank_models_service.delete_rerank_model(model_uuid)
|
||||||
|
|
||||||
|
return self.success()
|
||||||
|
|
||||||
|
@self.route('/<model_uuid>/test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY)
|
||||||
|
async def _(model_uuid: str) -> str:
|
||||||
|
json_data = await quart.request.json
|
||||||
|
|
||||||
|
await self.ap.rerank_models_service.test_rerank_model(model_uuid, json_data)
|
||||||
|
|
||||||
|
return self.success()
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
|||||||
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
|
counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid'])
|
||||||
provider['llm_count'] = counts['llm_count']
|
provider['llm_count'] = counts['llm_count']
|
||||||
provider['embedding_count'] = counts['embedding_count']
|
provider['embedding_count'] = counts['embedding_count']
|
||||||
|
provider['rerank_count'] = counts['rerank_count']
|
||||||
return self.success(data={'providers': providers})
|
return self.success(data={'providers': providers})
|
||||||
elif quart.request.method == 'POST':
|
elif quart.request.method == 'POST':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
@@ -32,6 +33,7 @@ class ModelProvidersRouterGroup(group.RouterGroup):
|
|||||||
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
|
counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid)
|
||||||
provider['llm_count'] = counts['llm_count']
|
provider['llm_count'] = counts['llm_count']
|
||||||
provider['embedding_count'] = counts['embedding_count']
|
provider['embedding_count'] = counts['embedding_count']
|
||||||
|
provider['rerank_count'] = counts['rerank_count']
|
||||||
return self.success(data={'provider': provider})
|
return self.success(data={'provider': provider})
|
||||||
elif quart.request.method == 'PUT':
|
elif quart.request.method == 'PUT':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
|
|||||||
@@ -367,3 +367,162 @@ class EmbeddingModelsService:
|
|||||||
input_text=['Hello, world!'],
|
input_text=['Hello, world!'],
|
||||||
extra_args={},
|
extra_args={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RerankModelsService:
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application) -> None:
|
||||||
|
self.ap = ap
|
||||||
|
|
||||||
|
async def get_rerank_models(self) -> list[dict]:
|
||||||
|
"""Get all rerank models with provider info"""
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
|
||||||
|
models = result.all()
|
||||||
|
|
||||||
|
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.ModelProvider)
|
||||||
|
)
|
||||||
|
providers = {p.uuid: p for p in providers_result.all()}
|
||||||
|
|
||||||
|
models_list = []
|
||||||
|
for model in models:
|
||||||
|
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
|
||||||
|
provider = providers.get(model.provider_uuid)
|
||||||
|
if provider:
|
||||||
|
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||||
|
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||||
|
models_list.append(model_dict)
|
||||||
|
|
||||||
|
return models_list
|
||||||
|
|
||||||
|
async def get_rerank_models_by_provider(self, provider_uuid: str) -> list[dict]:
|
||||||
|
"""Get rerank models by provider UUID"""
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.RerankModel).where(
|
||||||
|
persistence_model.RerankModel.provider_uuid == provider_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
models = result.all()
|
||||||
|
return [self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, m) for m in models]
|
||||||
|
|
||||||
|
async def create_rerank_model(self, model_data: dict, preserve_uuid: bool = False) -> str:
|
||||||
|
"""Create a new rerank model"""
|
||||||
|
if not preserve_uuid:
|
||||||
|
model_data['uuid'] = str(uuid.uuid4())
|
||||||
|
|
||||||
|
if 'provider' in model_data:
|
||||||
|
provider_data = model_data.pop('provider')
|
||||||
|
if provider_data.get('uuid'):
|
||||||
|
model_data['provider_uuid'] = provider_data['uuid']
|
||||||
|
else:
|
||||||
|
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||||
|
requester=provider_data.get('requester', ''),
|
||||||
|
base_url=provider_data.get('base_url', ''),
|
||||||
|
api_keys=provider_data.get('api_keys', []),
|
||||||
|
)
|
||||||
|
model_data['provider_uuid'] = provider_uuid
|
||||||
|
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.insert(persistence_model.RerankModel).values(**model_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||||
|
if runtime_provider is None:
|
||||||
|
raise Exception('provider not found')
|
||||||
|
|
||||||
|
runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
|
||||||
|
persistence_model.RerankModel(**model_data),
|
||||||
|
runtime_provider,
|
||||||
|
)
|
||||||
|
self.ap.model_mgr.rerank_models.append(runtime_rerank_model)
|
||||||
|
|
||||||
|
return model_data['uuid']
|
||||||
|
|
||||||
|
async def get_rerank_model(self, model_uuid: str) -> dict | None:
|
||||||
|
"""Get a single rerank model with provider info"""
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
|
||||||
|
)
|
||||||
|
model = result.first()
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model)
|
||||||
|
|
||||||
|
provider_result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||||
|
persistence_model.ModelProvider.uuid == model.provider_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider = provider_result.first()
|
||||||
|
if provider:
|
||||||
|
provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider)
|
||||||
|
model_dict['provider'] = _parse_provider_api_keys(provider_dict)
|
||||||
|
|
||||||
|
return model_dict
|
||||||
|
|
||||||
|
async def update_rerank_model(self, model_uuid: str, model_data: dict) -> None:
|
||||||
|
"""Update an existing rerank model"""
|
||||||
|
if 'uuid' in model_data:
|
||||||
|
del model_data['uuid']
|
||||||
|
|
||||||
|
if 'provider' in model_data:
|
||||||
|
provider_data = model_data.pop('provider')
|
||||||
|
if provider_data.get('uuid'):
|
||||||
|
model_data['provider_uuid'] = provider_data['uuid']
|
||||||
|
else:
|
||||||
|
provider_uuid = await self.ap.provider_service.find_or_create_provider(
|
||||||
|
requester=provider_data.get('requester', ''),
|
||||||
|
base_url=provider_data.get('base_url', ''),
|
||||||
|
api_keys=provider_data.get('api_keys', []),
|
||||||
|
)
|
||||||
|
model_data['provider_uuid'] = provider_uuid
|
||||||
|
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.update(persistence_model.RerankModel)
|
||||||
|
.where(persistence_model.RerankModel.uuid == model_uuid)
|
||||||
|
.values(**model_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.ap.model_mgr.remove_rerank_model(model_uuid)
|
||||||
|
|
||||||
|
runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid'])
|
||||||
|
if runtime_provider is None:
|
||||||
|
raise Exception('provider not found')
|
||||||
|
|
||||||
|
runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider(
|
||||||
|
persistence_model.RerankModel(**model_data),
|
||||||
|
runtime_provider,
|
||||||
|
)
|
||||||
|
self.ap.model_mgr.rerank_models.append(runtime_rerank_model)
|
||||||
|
|
||||||
|
async def delete_rerank_model(self, model_uuid: str) -> None:
|
||||||
|
"""Delete a rerank model"""
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.delete(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid)
|
||||||
|
)
|
||||||
|
await self.ap.model_mgr.remove_rerank_model(model_uuid)
|
||||||
|
|
||||||
|
async def test_rerank_model(self, model_uuid: str, model_data: dict) -> None:
|
||||||
|
"""Test a rerank model"""
|
||||||
|
runtime_rerank_model: model_requester.RuntimeRerankModel | None = None
|
||||||
|
|
||||||
|
if model_uuid != '_':
|
||||||
|
for model in self.ap.model_mgr.rerank_models:
|
||||||
|
if model.model_entity.uuid == model_uuid:
|
||||||
|
runtime_rerank_model = model
|
||||||
|
break
|
||||||
|
if runtime_rerank_model is None:
|
||||||
|
raise Exception('model not found')
|
||||||
|
else:
|
||||||
|
runtime_rerank_model = await self.ap.model_mgr.init_temporary_runtime_rerank_model(model_data)
|
||||||
|
|
||||||
|
await runtime_rerank_model.provider.invoke_rerank(
|
||||||
|
model=runtime_rerank_model,
|
||||||
|
query='What is artificial intelligence?',
|
||||||
|
documents=[
|
||||||
|
'Artificial intelligence is a branch of computer science.',
|
||||||
|
'The weather is nice today.',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -98,6 +98,14 @@ class ModelProviderService:
|
|||||||
if embedding_result.first() is not None:
|
if embedding_result.first() is not None:
|
||||||
raise ValueError('Cannot delete provider: Embedding models still reference it')
|
raise ValueError('Cannot delete provider: Embedding models still reference it')
|
||||||
|
|
||||||
|
rerank_result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(persistence_model.RerankModel).where(
|
||||||
|
persistence_model.RerankModel.provider_uuid == provider_uuid
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if rerank_result.first() is not None:
|
||||||
|
raise ValueError('Cannot delete provider: Rerank models still reference it')
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.delete(persistence_model.ModelProvider).where(
|
sqlalchemy.delete(persistence_model.ModelProvider).where(
|
||||||
persistence_model.ModelProvider.uuid == provider_uuid
|
persistence_model.ModelProvider.uuid == provider_uuid
|
||||||
@@ -122,7 +130,14 @@ class ModelProviderService:
|
|||||||
)
|
)
|
||||||
embedding_count = embedding_result.scalar() or 0
|
embedding_count = embedding_result.scalar() or 0
|
||||||
|
|
||||||
return {'llm_count': llm_count, 'embedding_count': embedding_count}
|
rerank_result = await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.select(sqlalchemy.func.count())
|
||||||
|
.select_from(persistence_model.RerankModel)
|
||||||
|
.where(persistence_model.RerankModel.provider_uuid == provider_uuid)
|
||||||
|
)
|
||||||
|
rerank_count = rerank_result.scalar() or 0
|
||||||
|
|
||||||
|
return {'llm_count': llm_count, 'embedding_count': embedding_count, 'rerank_count': rerank_count}
|
||||||
|
|
||||||
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
||||||
"""Find existing provider or create new one"""
|
"""Find existing provider or create new one"""
|
||||||
|
|||||||
@@ -133,6 +133,8 @@ class Application:
|
|||||||
|
|
||||||
embedding_models_service: model_service.EmbeddingModelsService = None
|
embedding_models_service: model_service.EmbeddingModelsService = None
|
||||||
|
|
||||||
|
rerank_models_service: model_service.RerankModelsService = None
|
||||||
|
|
||||||
provider_service: provider_service.ModelProviderService = None
|
provider_service: provider_service.ModelProviderService = None
|
||||||
|
|
||||||
pipeline_service: pipeline_service.PipelineService = None
|
pipeline_service: pipeline_service.PipelineService = None
|
||||||
|
|||||||
@@ -61,6 +61,9 @@ class BuildAppStage(stage.BootingStage):
|
|||||||
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
||||||
ap.embedding_models_service = embedding_models_service_inst
|
ap.embedding_models_service = embedding_models_service_inst
|
||||||
|
|
||||||
|
rerank_models_service_inst = model_service.RerankModelsService(ap)
|
||||||
|
ap.rerank_models_service = rerank_models_service_inst
|
||||||
|
|
||||||
provider_service_inst = provider_service.ModelProviderService(ap)
|
provider_service_inst = provider_service.ModelProviderService(ap)
|
||||||
ap.provider_service = provider_service_inst
|
ap.provider_service = provider_service_inst
|
||||||
|
|
||||||
|
|||||||
@@ -59,3 +59,22 @@ class EmbeddingModel(Base):
|
|||||||
server_default=sqlalchemy.func.now(),
|
server_default=sqlalchemy.func.now(),
|
||||||
onupdate=sqlalchemy.func.now(),
|
onupdate=sqlalchemy.func.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RerankModel(Base):
|
||||||
|
"""Rerank model"""
|
||||||
|
|
||||||
|
__tablename__ = 'rerank_models'
|
||||||
|
|
||||||
|
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||||
|
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||||
|
provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||||
|
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||||
|
prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||||
|
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||||
|
updated_at = sqlalchemy.Column(
|
||||||
|
sqlalchemy.DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=sqlalchemy.func.now(),
|
||||||
|
onupdate=sqlalchemy.func.now(),
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
"""add rerank_models table
|
||||||
|
|
||||||
|
Revision ID: 0003_add_rerank_models
|
||||||
|
Revises: 0002_sample
|
||||||
|
Create Date: 2026-04-19
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = '0003_add_rerank_models'
|
||||||
|
down_revision = '0002_sample'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
'rerank_models',
|
||||||
|
sa.Column('uuid', sa.String(255), primary_key=True, unique=True),
|
||||||
|
sa.Column('name', sa.String(255), nullable=False),
|
||||||
|
sa.Column('provider_uuid', sa.String(255), nullable=False),
|
||||||
|
sa.Column('extra_args', sa.JSON, nullable=False, server_default='{}'),
|
||||||
|
sa.Column('prefered_ranking', sa.Integer, nullable=False, server_default='0'),
|
||||||
|
sa.Column('created_at', sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column('updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table('rerank_models')
|
||||||
@@ -9,7 +9,6 @@ from ...discover import engine
|
|||||||
from . import token
|
from . import token
|
||||||
from ...entity.persistence import model as persistence_model
|
from ...entity.persistence import model as persistence_model
|
||||||
from ...entity.errors import provider as provider_errors
|
from ...entity.errors import provider as provider_errors
|
||||||
from async_lru import alru_cache
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
@@ -24,6 +23,8 @@ class ModelManager:
|
|||||||
|
|
||||||
embedding_models: list[requester.RuntimeEmbeddingModel]
|
embedding_models: list[requester.RuntimeEmbeddingModel]
|
||||||
|
|
||||||
|
rerank_models: list[requester.RuntimeRerankModel]
|
||||||
|
|
||||||
requester_components: list[engine.Component]
|
requester_components: list[engine.Component]
|
||||||
|
|
||||||
requester_dict: dict[str, type[requester.ProviderAPIRequester]]
|
requester_dict: dict[str, type[requester.ProviderAPIRequester]]
|
||||||
@@ -32,6 +33,7 @@ class ModelManager:
|
|||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.llm_models = []
|
self.llm_models = []
|
||||||
self.embedding_models = []
|
self.embedding_models = []
|
||||||
|
self.rerank_models = []
|
||||||
self.requester_components = []
|
self.requester_components = []
|
||||||
self.requester_dict = {}
|
self.requester_dict = {}
|
||||||
|
|
||||||
@@ -64,8 +66,7 @@ class ModelManager:
|
|||||||
|
|
||||||
self.llm_models = []
|
self.llm_models = []
|
||||||
self.embedding_models = []
|
self.embedding_models = []
|
||||||
|
self.rerank_models = []
|
||||||
# Load all providers first
|
|
||||||
self.provider_dict = {}
|
self.provider_dict = {}
|
||||||
providers_result = await self.ap.persistence_mgr.execute_async(
|
providers_result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_model.ModelProvider)
|
sqlalchemy.select(persistence_model.ModelProvider)
|
||||||
@@ -110,6 +111,22 @@ class ModelManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
|
self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||||
|
|
||||||
|
# Load rerank models
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel))
|
||||||
|
rerank_models = result.all()
|
||||||
|
for rerank_model in rerank_models:
|
||||||
|
try:
|
||||||
|
provider = self.provider_dict.get(rerank_model.provider_uuid)
|
||||||
|
if provider is None:
|
||||||
|
self.ap.logger.warning(
|
||||||
|
f'Provider {rerank_model.provider_uuid} not found for model {rerank_model.uuid}'
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
runtime_rerank_model = await self.load_rerank_model_with_provider(rerank_model, provider)
|
||||||
|
self.rerank_models.append(runtime_rerank_model)
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f'Failed to load model {rerank_model.uuid}: {e}\n{traceback.format_exc()}')
|
||||||
|
|
||||||
async def sync_new_models_from_space(self):
|
async def sync_new_models_from_space(self):
|
||||||
"""Sync models from Space"""
|
"""Sync models from Space"""
|
||||||
space_model_provider = await self.ap.persistence_mgr.execute_async(
|
space_model_provider = await self.ap.persistence_mgr.execute_async(
|
||||||
@@ -212,6 +229,26 @@ class ModelManager:
|
|||||||
|
|
||||||
return runtime_embedding_model
|
return runtime_embedding_model
|
||||||
|
|
||||||
|
async def init_temporary_runtime_rerank_model(
|
||||||
|
self,
|
||||||
|
model_info: dict,
|
||||||
|
) -> requester.RuntimeRerankModel:
|
||||||
|
"""Initialize runtime rerank model from dict (for testing)"""
|
||||||
|
provider_info = model_info.get('provider', {})
|
||||||
|
runtime_provider = await self.load_provider(provider_info)
|
||||||
|
|
||||||
|
runtime_rerank_model = requester.RuntimeRerankModel(
|
||||||
|
model_entity=persistence_model.RerankModel(
|
||||||
|
uuid=model_info.get('uuid', ''),
|
||||||
|
name=model_info.get('name', ''),
|
||||||
|
provider_uuid='',
|
||||||
|
extra_args=model_info.get('extra_args', {}),
|
||||||
|
),
|
||||||
|
provider=runtime_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
return runtime_rerank_model
|
||||||
|
|
||||||
async def load_provider(
|
async def load_provider(
|
||||||
self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict
|
self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict
|
||||||
) -> requester.RuntimeProvider:
|
) -> requester.RuntimeProvider:
|
||||||
@@ -269,6 +306,9 @@ class ModelManager:
|
|||||||
for model in self.embedding_models:
|
for model in self.embedding_models:
|
||||||
if model.provider.provider_entity.uuid == provider_uuid:
|
if model.provider.provider_entity.uuid == provider_uuid:
|
||||||
model.provider = new_runtime_provider
|
model.provider = new_runtime_provider
|
||||||
|
for model in self.rerank_models:
|
||||||
|
if model.provider.provider_entity.uuid == provider_uuid:
|
||||||
|
model.provider = new_runtime_provider
|
||||||
|
|
||||||
# update ref in provider dict
|
# update ref in provider dict
|
||||||
self.provider_dict[provider_uuid] = new_runtime_provider
|
self.provider_dict[provider_uuid] = new_runtime_provider
|
||||||
@@ -305,6 +345,22 @@ class ModelManager:
|
|||||||
|
|
||||||
return runtime_embedding_model
|
return runtime_embedding_model
|
||||||
|
|
||||||
|
async def load_rerank_model_with_provider(
|
||||||
|
self,
|
||||||
|
model_info: persistence_model.RerankModel | sqlalchemy.Row,
|
||||||
|
provider: requester.RuntimeProvider,
|
||||||
|
) -> requester.RuntimeRerankModel:
|
||||||
|
"""Load rerank model with provider info"""
|
||||||
|
if isinstance(model_info, sqlalchemy.Row):
|
||||||
|
model_info = persistence_model.RerankModel(**model_info._mapping)
|
||||||
|
|
||||||
|
runtime_rerank_model = requester.RuntimeRerankModel(
|
||||||
|
model_entity=model_info,
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
return runtime_rerank_model
|
||||||
|
|
||||||
async def load_llm_model(self, model_info: dict):
|
async def load_llm_model(self, model_info: dict):
|
||||||
"""Load LLM model from dict (with provider info)"""
|
"""Load LLM model from dict (with provider info)"""
|
||||||
provider_info = model_info.get('provider', {})
|
provider_info = model_info.get('provider', {})
|
||||||
@@ -352,7 +408,6 @@ class ModelManager:
|
|||||||
|
|
||||||
await self.load_embedding_model_with_provider(model_entity, provider_entity)
|
await self.load_embedding_model_with_provider(model_entity, provider_entity)
|
||||||
|
|
||||||
@alru_cache(ttl=60 * 5)
|
|
||||||
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
|
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
|
||||||
"""Get LLM model by uuid"""
|
"""Get LLM model by uuid"""
|
||||||
for model in self.llm_models:
|
for model in self.llm_models:
|
||||||
@@ -360,7 +415,6 @@ class ModelManager:
|
|||||||
return model
|
return model
|
||||||
raise ValueError(f'LLM model {uuid} not found')
|
raise ValueError(f'LLM model {uuid} not found')
|
||||||
|
|
||||||
@alru_cache(ttl=60 * 5)
|
|
||||||
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
|
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
|
||||||
"""Get embedding model by uuid"""
|
"""Get embedding model by uuid"""
|
||||||
for model in self.embedding_models:
|
for model in self.embedding_models:
|
||||||
@@ -368,6 +422,13 @@ class ModelManager:
|
|||||||
return model
|
return model
|
||||||
raise ValueError(f'Embedding model {uuid} not found')
|
raise ValueError(f'Embedding model {uuid} not found')
|
||||||
|
|
||||||
|
async def get_rerank_model_by_uuid(self, uuid: str) -> requester.RuntimeRerankModel:
|
||||||
|
"""Get rerank model by uuid"""
|
||||||
|
for model in self.rerank_models:
|
||||||
|
if model.model_entity.uuid == uuid:
|
||||||
|
return model
|
||||||
|
raise ValueError(f'Rerank model {uuid} not found')
|
||||||
|
|
||||||
async def remove_llm_model(self, model_uuid: str):
|
async def remove_llm_model(self, model_uuid: str):
|
||||||
"""Remove LLM model"""
|
"""Remove LLM model"""
|
||||||
for model in self.llm_models:
|
for model in self.llm_models:
|
||||||
@@ -382,6 +443,13 @@ class ModelManager:
|
|||||||
self.embedding_models.remove(model)
|
self.embedding_models.remove(model)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
async def remove_rerank_model(self, model_uuid: str):
|
||||||
|
"""Remove rerank model"""
|
||||||
|
for model in self.rerank_models:
|
||||||
|
if model.model_entity.uuid == model_uuid:
|
||||||
|
self.rerank_models.remove(model)
|
||||||
|
return
|
||||||
|
|
||||||
def get_available_requesters_info(self, model_type: str) -> list[dict]:
|
def get_available_requesters_info(self, model_type: str) -> list[dict]:
|
||||||
"""Get all available requesters"""
|
"""Get all available requesters"""
|
||||||
if model_type != '':
|
if model_type != '':
|
||||||
|
|||||||
@@ -247,6 +247,40 @@ class RuntimeProvider:
|
|||||||
except Exception as monitor_err:
|
except Exception as monitor_err:
|
||||||
self.requester.ap.logger.error(f'[Monitoring] Failed to record embedding call: {monitor_err}')
|
self.requester.ap.logger.error(f'[Monitoring] Failed to record embedding call: {monitor_err}')
|
||||||
|
|
||||||
|
async def invoke_rerank(
|
||||||
|
self,
|
||||||
|
model: RuntimeRerankModel,
|
||||||
|
query: str,
|
||||||
|
documents: typing.List[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> typing.List[dict]:
|
||||||
|
"""Bridge method for invoking rerank with monitoring"""
|
||||||
|
start_time = time.time()
|
||||||
|
status = 'success'
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self.requester.invoke_rerank(
|
||||||
|
model=model,
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
extra_args=extra_args,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
status = 'error'
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.requester.ap.logger.debug(
|
||||||
|
f'[Rerank] model={model.model_entity.name} docs={len(documents)} '
|
||||||
|
f'duration={duration_ms}ms status={status}'
|
||||||
|
)
|
||||||
|
except Exception as monitor_err:
|
||||||
|
self.requester.ap.logger.error(f'[Monitoring] Failed to record rerank call: {monitor_err}')
|
||||||
|
|
||||||
|
|
||||||
class RuntimeLLMModel:
|
class RuntimeLLMModel:
|
||||||
"""运行时模型"""
|
"""运行时模型"""
|
||||||
@@ -284,6 +318,24 @@ class RuntimeEmbeddingModel:
|
|||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeRerankModel:
|
||||||
|
"""运行时 Rerank 模型"""
|
||||||
|
|
||||||
|
model_entity: persistence_model.RerankModel
|
||||||
|
"""模型数据"""
|
||||||
|
|
||||||
|
provider: RuntimeProvider
|
||||||
|
"""提供商实例"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_entity: persistence_model.RerankModel,
|
||||||
|
provider: RuntimeProvider,
|
||||||
|
):
|
||||||
|
self.model_entity = model_entity
|
||||||
|
self.provider = provider
|
||||||
|
|
||||||
|
|
||||||
class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||||
"""Provider API请求器"""
|
"""Provider API请求器"""
|
||||||
|
|
||||||
@@ -376,3 +428,23 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
|||||||
或者 tuple[typing.List[typing.List[float]], dict]: 返回 (embedding 向量, usage_info)
|
或者 tuple[typing.List[typing.List[float]], dict]: 返回 (embedding 向量, usage_info)
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def invoke_rerank(
|
||||||
|
self,
|
||||||
|
model: RuntimeRerankModel,
|
||||||
|
query: str,
|
||||||
|
documents: typing.List[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> typing.List[dict]:
|
||||||
|
"""调用 Rerank API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (RuntimeRerankModel): 使用的模型信息
|
||||||
|
query (str): 查询文本
|
||||||
|
documents (typing.List[str]): 待重排序的文档列表
|
||||||
|
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
typing.List[dict]: [{"index": int, "relevance_score": float}, ...]
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('This requester does not support rerank')
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ spec:
|
|||||||
default: 120
|
default: 120
|
||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -615,3 +615,88 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
|||||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||||
except openai.APIError as e:
|
except openai.APIError as e:
|
||||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||||
|
|
||||||
|
async def invoke_rerank(
|
||||||
|
self,
|
||||||
|
model: requester.RuntimeRerankModel,
|
||||||
|
query: str,
|
||||||
|
documents: typing.List[str],
|
||||||
|
extra_args: dict[str, typing.Any] = {},
|
||||||
|
) -> typing.List[dict]:
|
||||||
|
"""Standard /rerank endpoint (Jina/Cohere/SiliconFlow/Voyage/DashScope compatible)
|
||||||
|
|
||||||
|
Supports extra_args from model.extra_args:
|
||||||
|
- rerank_url: full URL override (e.g. "https://dashscope.aliyuncs.com/compatible-api/v1/reranks")
|
||||||
|
- rerank_path: path override appended to base_url (e.g. "reranks" instead of default "rerank")
|
||||||
|
- Any other fields are merged into the request payload.
|
||||||
|
"""
|
||||||
|
api_key = model.provider.token_mgr.get_token()
|
||||||
|
base_url = self.requester_cfg.get('base_url', '').rstrip('/')
|
||||||
|
timeout = self.requester_cfg.get('timeout', 120)
|
||||||
|
|
||||||
|
merged_args = {}
|
||||||
|
if model.model_entity.extra_args:
|
||||||
|
merged_args.update(model.model_entity.extra_args)
|
||||||
|
if extra_args:
|
||||||
|
merged_args.update(extra_args)
|
||||||
|
|
||||||
|
rerank_url = merged_args.pop('rerank_url', None)
|
||||||
|
rerank_path = merged_args.pop('rerank_path', 'rerank')
|
||||||
|
if not rerank_url:
|
||||||
|
rerank_url = f'{base_url}/{rerank_path}'
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': f'Bearer {api_key}',
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
'model': model.model_entity.name,
|
||||||
|
'query': query,
|
||||||
|
'documents': documents[:64],
|
||||||
|
'top_n': min(len(documents), 64),
|
||||||
|
}
|
||||||
|
|
||||||
|
if merged_args:
|
||||||
|
payload.update(merged_args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client:
|
||||||
|
resp = await client.post(rerank_url, headers=headers, json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
results = self._parse_rerank_response(data)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
scores = [r.get('relevance_score', 0.0) for r in results]
|
||||||
|
min_score = min(scores)
|
||||||
|
max_score = max(scores)
|
||||||
|
if max_score - min_score > 1e-6:
|
||||||
|
for r in results:
|
||||||
|
r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score)
|
||||||
|
|
||||||
|
return results
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise errors.RequesterError(f'Rerank request failed: {e.response.status_code} - {e.response.text}')
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise errors.RequesterError('Rerank request timed out')
|
||||||
|
except Exception as e:
|
||||||
|
raise errors.RequesterError(f'Rerank request error: {str(e)}')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_rerank_response(data: dict) -> typing.List[dict]:
|
||||||
|
"""Parse rerank response from various providers.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Jina/Cohere/SiliconFlow: {"results": [{"index", "relevance_score"}]}
|
||||||
|
- Voyage AI: {"data": [{"index", "relevance_score"}]}
|
||||||
|
- DashScope: {"output": {"results": [{"index", "relevance_score"}]}}
|
||||||
|
"""
|
||||||
|
if 'results' in data:
|
||||||
|
return data['results']
|
||||||
|
if 'data' in data:
|
||||||
|
return data['data']
|
||||||
|
if 'output' in data and isinstance(data['output'], dict):
|
||||||
|
return data['output'].get('results', [])
|
||||||
|
return []
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: manufacturer
|
provider_category: manufacturer
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 128 128" id="Chroma--Streamline-Svg-Logos" height="128" width="128">
|
||||||
<rect width="24" height="24" rx="5" fill="#7B68EE"/>
|
<desc>
|
||||||
<circle cx="12" cy="12" r="6" fill="#FF6B35"/>
|
Chroma Streamline Icon: https://streamlinehq.com
|
||||||
<circle cx="12" cy="12" r="3" fill="#7B68EE"/>
|
</desc>
|
||||||
<path d="M12 6V18" stroke="#FFF" stroke-width="1.5" stroke-linecap="round"/>
|
<path fill="#ffde2d" d="M84.88839999999999 104.10666666666665c23.0732 0 41.77773333333333 -17.956266666666664 41.77773333333333 -40.10653333333333 0 -22.150266666666667 -18.70453333333333 -40.10653333333333 -41.77773333333333 -40.10653333333333 -23.0732 0 -41.77773333333333 17.956266666666664 -41.77773333333333 40.10653333333333 0 22.150266666666667 18.70453333333333 40.10653333333333 41.77773333333333 40.10653333333333Z" stroke-width="1.3333"></path>
|
||||||
<path d="M6 12H18" stroke="#FFF" stroke-width="1.5" stroke-linecap="round"/>
|
<path fill="#327eff" d="M43.111066666666666 104.10666666666665c23.0732 0 41.77773333333333 -17.956266666666664 41.77773333333333 -40.10653333333333 0 -22.150266666666667 -18.70453333333333 -40.10653333333333 -41.77773333333333 -40.10653333333333C20.037866666666666 23.8936 1.3333333333333333 41.849866666666664 1.3333333333333333 64.00013333333334 1.3333333333333333 86.15039999999999 20.037866666666666 104.10666666666665 43.111066666666666 104.10666666666665Z" stroke-width="1.3333"></path>
|
||||||
|
<path fill="#ff6446" d="M84.88866666666667 64.00013333333334c0 22.150399999999998 -18.704666666666665 40.10626666666666 -41.778 40.10626666666666V64.00013333333334h41.778Zm-41.778 0c0 -22.150266666666667 18.70453333333333 -40.10653333333333 41.778 -40.10653333333333v40.10653333333333H43.11066666666666Z" stroke-width="1.3333"></path>
|
||||||
</svg>
|
</svg>
|
||||||
|
Before Width: | Height: | Size: 413 B After Width: | Height: | Size: 1.5 KiB |
1
src/langbot/pkg/provider/modelmgr/requesters/cohere.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Cohere</title><path clip-rule="evenodd" d="M8.128 14.099c.592 0 1.77-.033 3.398-.703 1.897-.781 5.672-2.2 8.395-3.656 1.905-1.018 2.74-2.366 2.74-4.18A4.56 4.56 0 0018.1 1H7.549A6.55 6.55 0 001 7.55c0 3.617 2.745 6.549 7.128 6.549z" fill="#39594D" fill-rule="evenodd"></path><path clip-rule="evenodd" d="M9.912 18.61a4.387 4.387 0 012.705-4.052l3.323-1.38c3.361-1.394 7.06 1.076 7.06 4.715a5.104 5.104 0 01-5.105 5.104l-3.597-.001a4.386 4.386 0 01-4.386-4.387z" fill="#D18EE2" fill-rule="evenodd"></path><path d="M4.776 14.962A3.775 3.775 0 001 18.738v.489a3.776 3.776 0 007.551 0v-.49a3.775 3.775 0 00-3.775-3.775z" fill="#FF7759"></path></svg>
|
||||||
|
After Width: | Height: | Size: 769 B |
@@ -0,0 +1,31 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: cohere-rerank
|
||||||
|
label:
|
||||||
|
en_US: Cohere
|
||||||
|
zh_Hans: Cohere
|
||||||
|
icon: cohere.svg
|
||||||
|
spec:
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.cohere.com/v2
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./chatcmpl.py
|
||||||
|
attr: OpenAIChatCompletions
|
||||||
@@ -25,6 +25,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
1
src/langbot/pkg/provider/modelmgr/requesters/jina.svg
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<svg fill="currentColor" fill-rule="evenodd" height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Jina</title><path d="M6.608 21.416a4.608 4.608 0 100-9.217 4.608 4.608 0 000 9.217zM20.894 2.015c.614 0 1.106.492 1.106 1.106v9.002c0 5.13-4.148 9.309-9.217 9.37v-9.355l-.03-9.032c0-.614.491-1.106 1.106-1.106h7.158l-.123.015z"></path></svg>
|
||||||
|
After Width: | Height: | Size: 404 B |
31
src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: jina-rerank
|
||||||
|
label:
|
||||||
|
en_US: Jina
|
||||||
|
zh_Hans: Jina
|
||||||
|
icon: jina.svg
|
||||||
|
spec:
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.jina.ai/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./chatcmpl.py
|
||||||
|
attr: OpenAIChatCompletions
|
||||||
@@ -25,6 +25,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -1,8 +1,17 @@
|
|||||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg id="_图层_1" data-name="图层 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 334.84 76.22">
|
||||||
<rect width="24" height="24" rx="5" fill="#1E3A5F"/>
|
<defs>
|
||||||
<path d="M6 12C6 8.68629 8.68629 6 12 6C15.3137 6 18 8.68629 18 12" stroke="#4FC3F7" stroke-width="2" stroke-linecap="round"/>
|
<style>
|
||||||
<path d="M18 12C18 15.3137 15.3137 18 12 18C8.68629 18 6 15.3137 6 12" stroke="#81D4FA" stroke-width="2" stroke-linecap="round"/>
|
.cls-1 {
|
||||||
<circle cx="12" cy="12" r="2" fill="#4FC3F7"/>
|
fill: currentColor;
|
||||||
<circle cx="6" cy="12" r="1.5" fill="#81D4FA"/>
|
}
|
||||||
<circle cx="18" cy="12" r="1.5" fill="#4FC3F7"/>
|
</style>
|
||||||
</svg>
|
</defs>
|
||||||
|
<path class="cls-1" d="M308.56,23.63c-5.04,0-9.73,1.43-13.73,3.88V1.08l-12.56,4.61v70h12.56v-3.35c4,2.46,8.71,3.88,13.73,3.88,14.49,0,26.29-11.79,26.29-26.29s-11.79-26.29-26.29-26.29h0ZM308.56,63.88c-6.87,0-12.57-4.98-13.73-11.51v-4.91c1.16-6.54,6.88-11.51,13.73-11.51,7.7,0,13.96,6.26,13.96,13.96s-6.26,13.96-13.96,13.96Z"></path>
|
||||||
|
<path class="cls-1" d="M255.54,5.69v21.83c-4-2.46-8.71-3.88-13.73-3.88-14.49,0-26.29,11.79-26.29,26.29s11.79,26.29,26.29,26.29c5.04,0,9.73-1.43,13.73-3.88v3.35h12.56V1.08l-12.56,4.61ZM241.81,63.88c-7.7,0-13.96-6.26-13.96-13.96s6.26-13.96,13.96-13.96c6.87,0,12.57,4.98,13.73,11.51v4.91c-1.16,6.54-6.88,11.51-13.73,11.51Z"></path>
|
||||||
|
<polygon class="cls-1" points="195.35 52.2 186.65 61.17 200.64 75.62 209.32 75.62 218.01 75.62 195.35 52.2"></polygon>
|
||||||
|
<path class="cls-1" d="M167.14,4.59c.65,3.99.68,8.04.03,12.15-.03.17.16.3.31.21,3.82-2.21,7.82-3.69,12.01-4.33.12-.02.19-.13.17-.23-.68-4.13-.61-8.18-.03-12.16.02-.17-.16-.3-.31-.2-4.01,2.31-8.01,3.81-12.01,4.34-.12.01-.19.12-.17.23h0Z"></path>
|
||||||
|
<path class="cls-1" d="M198.75,24.09l-19.07,19.72v-25.57c-4.49.67-8.7,2.11-12.56,4.57v52.83h12.56v-13.87l3.78-3.9.02.02,8.68-8.97-.02-.02,23.98-24.8h-17.37Z"></path>
|
||||||
|
<path class="cls-1" d="M145.03,57.86c-2.56,4.45-7.17,7.2-12.13,7.2-5.96,0-11.3-3.96-13.32-9.85h38.87l.08-.42c.29-1.5.42-3.06.42-4.65,0-14.37-11.69-26.06-26.06-26.06s-26.06,11.69-26.06,26.06,11.69,26.06,26.06,26.06c9.63,0,18.43-5.28,22.98-13.77l.26-.49-11.1-4.08h-.01ZM132.88,35.19h.03c5.96,0,11.3,3.96,13.32,9.85h-26.67c2.02-5.89,7.36-9.85,13.32-9.85Z"></path>
|
||||||
|
<path class="cls-1" d="M75.92,65.07c-5.96,0-11.29-3.96-13.32-9.85h38.87l.08-.42c.29-1.5.42-3.06.42-4.65,0-14.37-11.69-26.06-26.06-26.06s-26.06,11.69-26.06,26.06,11.69,26.06,26.06,26.06c9.63,0,18.43-5.28,22.98-13.77l.26-.49h0l-11.1-4.08c-2.56,4.45-7.17,7.2-12.13,7.2h-.01ZM75.92,35.19h.03c5.96,0,11.29,3.96,13.32,9.85h-26.67c2.03-5.89,7.36-9.85,13.32-9.85Z"></path>
|
||||||
|
<path class="cls-1" d="M30.43,45.58l-10.2-1.91c-3.03-.56-4.98-2.25-4.98-4.33,0-1.5,1.61-4.35,7.68-4.35,5.53,0,9.36,3.5,10.25,6.26l10.9-4-.14-.42c-1.17-3.54-3.5-6.58-6.94-9.04-3.49-2.49-8.04-3.69-13.88-3.69s-10.98,1.5-14.78,4.34c-3.88,2.91-5.84,6.76-5.84,11.46,0,7.98,4.72,12.77,14.42,14.64l9.9,1.81c3.05.61,4.94,2.27,4.94,4.33,0,2.61-3.58,4.44-8.7,4.44-5.79,0-9.9-3.72-11.85-7.14L0,62.1l.14.39c1.3,3.8,3.89,7.07,7.7,9.71,3.78,2.6,8.65,3.95,14.51,3.98l.25.03c6.87,0,12.55-1.57,16.43-4.53,3.98-3.05,6-6.99,6-11.74,0-3.73-1.14-6.7-3.6-9.33-2.27-2.42-5.98-4.11-10.98-5.02h-.02Z"></path>
|
||||||
|
</svg>
|
||||||
|
Before Width: | Height: | Size: 569 B After Width: | Height: | Size: 2.7 KiB |
@@ -25,6 +25,7 @@ spec:
|
|||||||
support_type:
|
support_type:
|
||||||
- llm
|
- llm
|
||||||
- text-embedding
|
- text-embedding
|
||||||
|
- rerank
|
||||||
provider_category: maas
|
provider_category: maas
|
||||||
execution:
|
execution:
|
||||||
python:
|
python:
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Voyage</title><path d="M5.407 0v.066a.974.974 0 00-.048.245c-.011.11-.016.208-.016.295 0 .339.043.715.128 1.13.097.405.274.912.531 1.524l7.125 16.366L20.011 3.39c.161-.404.333-.846.515-1.327.182-.48.273-.966.273-1.458a1.406 1.406 0 00-.096-.54V0H24v.066c-.204.207-.45.578-.74 1.114-.29.535-.606 1.195-.949 1.982L13.095 24h-1.287L3.075 3.965c-.204-.47-.418-.923-.644-1.36-.214-.437-.418-.83-.61-1.18-.194-.36-.365-.66-.515-.9A5.666 5.666 0 001 .064V0h4.407z" fill="#012E33"></path></svg>
|
||||||
|
After Width: | Height: | Size: 610 B |
@@ -0,0 +1,31 @@
|
|||||||
|
apiVersion: v1
|
||||||
|
kind: LLMAPIRequester
|
||||||
|
metadata:
|
||||||
|
name: voyageai-rerank
|
||||||
|
label:
|
||||||
|
en_US: Voyage AI
|
||||||
|
zh_Hans: Voyage AI
|
||||||
|
icon: voyageai.svg
|
||||||
|
spec:
|
||||||
|
config:
|
||||||
|
- name: base_url
|
||||||
|
label:
|
||||||
|
en_US: Base URL
|
||||||
|
zh_Hans: 基础 URL
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
default: https://api.voyageai.com/v1
|
||||||
|
- name: timeout
|
||||||
|
label:
|
||||||
|
en_US: Timeout
|
||||||
|
zh_Hans: 超时时间
|
||||||
|
type: integer
|
||||||
|
required: true
|
||||||
|
default: 120
|
||||||
|
support_type:
|
||||||
|
- rerank
|
||||||
|
provider_category: manufacturer
|
||||||
|
execution:
|
||||||
|
python:
|
||||||
|
path: ./chatcmpl.py
|
||||||
|
attr: OpenAIChatCompletions
|
||||||
@@ -172,6 +172,45 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
if result:
|
if result:
|
||||||
all_results.extend(result)
|
all_results.extend(result)
|
||||||
|
|
||||||
|
# Rerank step: re-score results using a rerank model if configured
|
||||||
|
local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {})
|
||||||
|
rerank_model_uuid = local_agent_config.get('rerank-model', '')
|
||||||
|
if rerank_model_uuid == '__none__':
|
||||||
|
rerank_model_uuid = ''
|
||||||
|
self.ap.logger.info(
|
||||||
|
f'Rerank config: model_uuid={rerank_model_uuid!r}, '
|
||||||
|
f'results={len(all_results)}, '
|
||||||
|
f'local_agent_keys={list(local_agent_config.keys())}'
|
||||||
|
)
|
||||||
|
if all_results and rerank_model_uuid:
|
||||||
|
try:
|
||||||
|
rerank_model = await self.ap.model_mgr.get_rerank_model_by_uuid(rerank_model_uuid)
|
||||||
|
rerank_top_k = int(local_agent_config.get('rerank-top-k', 5))
|
||||||
|
|
||||||
|
doc_texts = []
|
||||||
|
for entry in all_results:
|
||||||
|
text = ' '.join(c.text for c in entry.content if c.type == 'text' and c.text)
|
||||||
|
doc_texts.append(text)
|
||||||
|
|
||||||
|
doc_texts_capped = doc_texts[:64]
|
||||||
|
scores = await rerank_model.provider.invoke_rerank(
|
||||||
|
model=rerank_model,
|
||||||
|
query=user_message_text,
|
||||||
|
documents=doc_texts_capped,
|
||||||
|
)
|
||||||
|
|
||||||
|
scored = sorted(scores, key=lambda x: x.get('relevance_score', 0), reverse=True)
|
||||||
|
top_indices = [s['index'] for s in scored[:rerank_top_k] if s['index'] < len(all_results)]
|
||||||
|
all_results = [all_results[i] for i in top_indices]
|
||||||
|
|
||||||
|
self.ap.logger.info(
|
||||||
|
f'Rerank complete: {len(doc_texts)} docs reranked -> top {len(all_results)} kept (top_k={rerank_top_k})'
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
self.ap.logger.warning(f'Rerank model {rerank_model_uuid} not found, skipping rerank')
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.warning(f'Rerank failed, using original order: {e}')
|
||||||
|
|
||||||
final_user_message_text = ''
|
final_user_message_text = ''
|
||||||
|
|
||||||
if all_results:
|
if all_results:
|
||||||
|
|||||||
@@ -52,7 +52,9 @@
|
|||||||
"content": "You are a helpful assistant."
|
"content": "You are a helpful assistant."
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"knowledge-bases": []
|
"knowledge-bases": [],
|
||||||
|
"rerank-model": "",
|
||||||
|
"rerank-top-k": 5
|
||||||
},
|
},
|
||||||
"dify-service-api": {
|
"dify-service-api": {
|
||||||
"base-url": "https://api.dify.ai/v1",
|
"base-url": "https://api.dify.ai/v1",
|
||||||
|
|||||||
@@ -104,6 +104,34 @@ stages:
|
|||||||
field: __system.is_wizard
|
field: __system.is_wizard
|
||||||
operator: neq
|
operator: neq
|
||||||
value: true
|
value: true
|
||||||
|
- name: rerank-model
|
||||||
|
label:
|
||||||
|
en_US: Rerank Model
|
||||||
|
zh_Hans: 重排序模型
|
||||||
|
description:
|
||||||
|
en_US: Optional rerank model to improve retrieval quality by re-scoring retrieved chunks
|
||||||
|
zh_Hans: 可选的重排序模型,通过重新评分检索结果来提升检索质量
|
||||||
|
type: rerank-model-selector
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
show_if:
|
||||||
|
field: knowledge-bases
|
||||||
|
operator: neq
|
||||||
|
value: []
|
||||||
|
- name: rerank-top-k
|
||||||
|
label:
|
||||||
|
en_US: Rerank Top K
|
||||||
|
zh_Hans: 重排序保留数量
|
||||||
|
description:
|
||||||
|
en_US: Number of top results to keep after reranking
|
||||||
|
zh_Hans: 重排序后保留的最相关结果数量
|
||||||
|
type: integer
|
||||||
|
required: false
|
||||||
|
default: 5
|
||||||
|
show_if:
|
||||||
|
field: rerank-model
|
||||||
|
operator: neq
|
||||||
|
value: ''
|
||||||
- name: dify-service-api
|
- name: dify-service-api
|
||||||
label:
|
label:
|
||||||
en_US: Dify Service API
|
en_US: Dify Service API
|
||||||
|
|||||||
@@ -240,6 +240,9 @@ export default function DynamicFormComponent({
|
|||||||
case 'embedding-model-selector':
|
case 'embedding-model-selector':
|
||||||
fieldSchema = z.string();
|
fieldSchema = z.string();
|
||||||
break;
|
break;
|
||||||
|
case 'rerank-model-selector':
|
||||||
|
fieldSchema = z.string();
|
||||||
|
break;
|
||||||
case 'knowledge-base-selector':
|
case 'knowledge-base-selector':
|
||||||
fieldSchema = z.string();
|
fieldSchema = z.string();
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import {
|
|||||||
Bot,
|
Bot,
|
||||||
KnowledgeBase,
|
KnowledgeBase,
|
||||||
EmbeddingModel,
|
EmbeddingModel,
|
||||||
|
RerankModel,
|
||||||
PluginTool,
|
PluginTool,
|
||||||
} from '@/app/infra/entities/api';
|
} from '@/app/infra/entities/api';
|
||||||
import { toast } from 'sonner';
|
import { toast } from 'sonner';
|
||||||
@@ -74,6 +75,7 @@ export default function DynamicFormItemComponent({
|
|||||||
}) {
|
}) {
|
||||||
const [llmModels, setLlmModels] = useState<LLMModel[]>([]);
|
const [llmModels, setLlmModels] = useState<LLMModel[]>([]);
|
||||||
const [embeddingModels, setEmbeddingModels] = useState<EmbeddingModel[]>([]);
|
const [embeddingModels, setEmbeddingModels] = useState<EmbeddingModel[]>([]);
|
||||||
|
const [rerankModels, setRerankModels] = useState<RerankModel[]>([]);
|
||||||
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBase[]>([]);
|
const [knowledgeBases, setKnowledgeBases] = useState<KnowledgeBase[]>([]);
|
||||||
const [bots, setBots] = useState<Bot[]>([]);
|
const [bots, setBots] = useState<Bot[]>([]);
|
||||||
const [tools, setTools] = useState<PluginTool[]>([]);
|
const [tools, setTools] = useState<PluginTool[]>([]);
|
||||||
@@ -180,6 +182,19 @@ export default function DynamicFormItemComponent({
|
|||||||
}
|
}
|
||||||
}, [config.type]);
|
}, [config.type]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (config.type === DynamicFormItemType.RERANK_MODEL_SELECTOR) {
|
||||||
|
httpClient
|
||||||
|
.getProviderRerankModels()
|
||||||
|
.then((resp) => {
|
||||||
|
setRerankModels(resp.models);
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
toast.error('Failed to load rerank models: ' + err.msg);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [config.type]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (config.type === DynamicFormItemType.MODEL_FALLBACK_SELECTOR) {
|
if (config.type === DynamicFormItemType.MODEL_FALLBACK_SELECTOR) {
|
||||||
fetchLlmModels();
|
fetchLlmModels();
|
||||||
@@ -585,6 +600,45 @@ export default function DynamicFormItemComponent({
|
|||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
case DynamicFormItemType.RERANK_MODEL_SELECTOR:
|
||||||
|
const groupedRerankModels = rerankModels.reduce(
|
||||||
|
(acc, model) => {
|
||||||
|
const providerName = model.provider?.name || 'Unknown';
|
||||||
|
if (!acc[providerName]) acc[providerName] = [];
|
||||||
|
acc[providerName].push(model);
|
||||||
|
return acc;
|
||||||
|
},
|
||||||
|
{} as Record<string, RerankModel[]>,
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-md">
|
||||||
|
<Select
|
||||||
|
value={field.value || '__none__'}
|
||||||
|
onValueChange={(v) => field.onChange(v === '__none__' ? '' : v)}
|
||||||
|
>
|
||||||
|
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
|
||||||
|
<SelectValue placeholder={t('models.rerank')} />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
<SelectItem value="__none__">{t('common.none')}</SelectItem>
|
||||||
|
{Object.entries(groupedRerankModels).map(
|
||||||
|
([providerName, models]) => (
|
||||||
|
<SelectGroup key={providerName}>
|
||||||
|
<SelectLabel>{providerName}</SelectLabel>
|
||||||
|
{models.map((model) => (
|
||||||
|
<SelectItem key={model.uuid} value={model.uuid}>
|
||||||
|
{model.name}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectGroup>
|
||||||
|
),
|
||||||
|
)}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
case DynamicFormItemType.MODEL_FALLBACK_SELECTOR: {
|
case DynamicFormItemType.MODEL_FALLBACK_SELECTOR: {
|
||||||
// Separate space models from regular models
|
// Separate space models from regular models
|
||||||
const fbSpaceModels = llmModels.filter(
|
const fbSpaceModels = llmModels.filter(
|
||||||
|
|||||||
@@ -147,15 +147,17 @@ export default function ModelsDialog({
|
|||||||
setLoadingProviders((prev) => new Set(prev).add(providerUuid));
|
setLoadingProviders((prev) => new Set(prev).add(providerUuid));
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
const [llmResp, embeddingResp] = await Promise.all([
|
const [llmResp, embeddingResp, rerankResp] = await Promise.all([
|
||||||
httpClient.getProviderLLMModels(providerUuid),
|
httpClient.getProviderLLMModels(providerUuid),
|
||||||
httpClient.getProviderEmbeddingModels(providerUuid),
|
httpClient.getProviderEmbeddingModels(providerUuid),
|
||||||
|
httpClient.getProviderRerankModels(providerUuid),
|
||||||
]);
|
]);
|
||||||
setProviderModels((prev) => ({
|
setProviderModels((prev) => ({
|
||||||
...prev,
|
...prev,
|
||||||
[providerUuid]: {
|
[providerUuid]: {
|
||||||
llm: llmResp.models,
|
llm: llmResp.models,
|
||||||
embedding: embeddingResp.models,
|
embedding: embeddingResp.models,
|
||||||
|
rerank: rerankResp.models,
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
@@ -247,12 +249,18 @@ export default function ModelsDialog({
|
|||||||
abilities,
|
abilities,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
} else {
|
} else if (modelType === 'embedding') {
|
||||||
await httpClient.createProviderEmbeddingModel({
|
await httpClient.createProviderEmbeddingModel({
|
||||||
name,
|
name,
|
||||||
provider_uuid: providerUuid,
|
provider_uuid: providerUuid,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
|
} else {
|
||||||
|
await httpClient.createProviderRerankModel({
|
||||||
|
name,
|
||||||
|
provider_uuid: providerUuid,
|
||||||
|
extra_args: extraArgsObj,
|
||||||
|
} as never);
|
||||||
}
|
}
|
||||||
setAddModelPopoverOpen(null);
|
setAddModelPopoverOpen(null);
|
||||||
loadProviderModels(providerUuid, true);
|
loadProviderModels(providerUuid, true);
|
||||||
@@ -341,12 +349,18 @@ export default function ModelsDialog({
|
|||||||
abilities,
|
abilities,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
} else {
|
} else if (modelType === 'embedding') {
|
||||||
await httpClient.updateProviderEmbeddingModel(modelId, {
|
await httpClient.updateProviderEmbeddingModel(modelId, {
|
||||||
name,
|
name,
|
||||||
provider_uuid: providerUuid,
|
provider_uuid: providerUuid,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
|
} else {
|
||||||
|
await httpClient.updateProviderRerankModel(modelId, {
|
||||||
|
name,
|
||||||
|
provider_uuid: providerUuid,
|
||||||
|
extra_args: extraArgsObj,
|
||||||
|
} as never);
|
||||||
}
|
}
|
||||||
setEditModelPopoverOpen(null);
|
setEditModelPopoverOpen(null);
|
||||||
loadProviderModels(providerUuid, true);
|
loadProviderModels(providerUuid, true);
|
||||||
@@ -366,8 +380,10 @@ export default function ModelsDialog({
|
|||||||
try {
|
try {
|
||||||
if (modelType === 'llm') {
|
if (modelType === 'llm') {
|
||||||
await httpClient.deleteProviderLLMModel(modelId);
|
await httpClient.deleteProviderLLMModel(modelId);
|
||||||
} else {
|
} else if (modelType === 'embedding') {
|
||||||
await httpClient.deleteProviderEmbeddingModel(modelId);
|
await httpClient.deleteProviderEmbeddingModel(modelId);
|
||||||
|
} else {
|
||||||
|
await httpClient.deleteProviderRerankModel(modelId);
|
||||||
}
|
}
|
||||||
toast.success(t('models.deleteSuccess'));
|
toast.success(t('models.deleteSuccess'));
|
||||||
loadProviderModels(providerUuid, true);
|
loadProviderModels(providerUuid, true);
|
||||||
@@ -407,7 +423,7 @@ export default function ModelsDialog({
|
|||||||
abilities,
|
abilities,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
} else {
|
} else if (modelType === 'embedding') {
|
||||||
await httpClient.testEmbeddingModel('_', {
|
await httpClient.testEmbeddingModel('_', {
|
||||||
uuid: '',
|
uuid: '',
|
||||||
name,
|
name,
|
||||||
@@ -415,6 +431,14 @@ export default function ModelsDialog({
|
|||||||
provider: providerData,
|
provider: providerData,
|
||||||
extra_args: extraArgsObj,
|
extra_args: extraArgsObj,
|
||||||
} as never);
|
} as never);
|
||||||
|
} else {
|
||||||
|
await httpClient.testRerankModel('_', {
|
||||||
|
uuid: '',
|
||||||
|
name,
|
||||||
|
provider_uuid: '',
|
||||||
|
provider: providerData,
|
||||||
|
extra_args: extraArgsObj,
|
||||||
|
} as never);
|
||||||
}
|
}
|
||||||
const duration = Date.now() - startTime;
|
const duration = Date.now() - startTime;
|
||||||
setTestResult({ success: true, duration });
|
setTestResult({ success: true, duration });
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import {
|
|||||||
Plus,
|
Plus,
|
||||||
MessageSquareText,
|
MessageSquareText,
|
||||||
Cpu,
|
Cpu,
|
||||||
|
ArrowUpDown,
|
||||||
Eye,
|
Eye,
|
||||||
Wrench,
|
Wrench,
|
||||||
Check,
|
Check,
|
||||||
@@ -265,7 +266,7 @@ export default function AddModelPopover({
|
|||||||
onClick={(e) => e.stopPropagation()}
|
onClick={(e) => e.stopPropagation()}
|
||||||
>
|
>
|
||||||
<Tabs value={tab} onValueChange={(v) => setTab(v as ModelType)}>
|
<Tabs value={tab} onValueChange={(v) => setTab(v as ModelType)}>
|
||||||
<TabsList className="grid w-full grid-cols-2">
|
<TabsList className="grid w-full grid-cols-3">
|
||||||
<TabsTrigger value="llm">
|
<TabsTrigger value="llm">
|
||||||
<MessageSquareText className="h-4 w-4 mr-1" />
|
<MessageSquareText className="h-4 w-4 mr-1" />
|
||||||
{t('models.chat')}
|
{t('models.chat')}
|
||||||
@@ -274,6 +275,10 @@ export default function AddModelPopover({
|
|||||||
<Cpu className="h-4 w-4 mr-1" />
|
<Cpu className="h-4 w-4 mr-1" />
|
||||||
{t('models.embedding')}
|
{t('models.embedding')}
|
||||||
</TabsTrigger>
|
</TabsTrigger>
|
||||||
|
<TabsTrigger value="rerank">
|
||||||
|
<ArrowUpDown className="h-4 w-4 mr-1" />
|
||||||
|
{t('models.rerank')}
|
||||||
|
</TabsTrigger>
|
||||||
</TabsList>
|
</TabsList>
|
||||||
|
|
||||||
<Tabs
|
<Tabs
|
||||||
@@ -330,7 +335,11 @@ export default function AddModelPopover({
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<ExtraArgsEditor args={extraArgs} onChange={setExtraArgs} />
|
<ExtraArgsEditor
|
||||||
|
args={extraArgs}
|
||||||
|
onChange={setExtraArgs}
|
||||||
|
modelType={tab}
|
||||||
|
/>
|
||||||
<div className="flex gap-2">
|
<div className="flex gap-2">
|
||||||
<Button
|
<Button
|
||||||
className="flex-1"
|
className="flex-1"
|
||||||
@@ -467,7 +476,9 @@ export default function AddModelPopover({
|
|||||||
? t('models.alreadyAdded')
|
? t('models.alreadyAdded')
|
||||||
: model.type === 'llm'
|
: model.type === 'llm'
|
||||||
? t('models.chat')
|
? t('models.chat')
|
||||||
: t('models.embedding')}
|
: model.type === 'embedding'
|
||||||
|
? t('models.embedding')
|
||||||
|
: t('models.rerank')}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { Plus, X } from 'lucide-react';
|
import { Plus, X, HelpCircle } from 'lucide-react';
|
||||||
import { Button } from '@/components/ui/button';
|
import { Button } from '@/components/ui/button';
|
||||||
import { Input } from '@/components/ui/input';
|
import { Input } from '@/components/ui/input';
|
||||||
import { Label } from '@/components/ui/label';
|
import { Label } from '@/components/ui/label';
|
||||||
@@ -9,19 +9,26 @@ import {
|
|||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
} from '@/components/ui/select';
|
} from '@/components/ui/select';
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from '@/components/ui/tooltip';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ExtraArg } from '../types';
|
import { ExtraArg, ModelType } from '../types';
|
||||||
|
|
||||||
interface ExtraArgsEditorProps {
|
interface ExtraArgsEditorProps {
|
||||||
args: ExtraArg[];
|
args: ExtraArg[];
|
||||||
onChange: (args: ExtraArg[]) => void;
|
onChange: (args: ExtraArg[]) => void;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
|
modelType?: ModelType;
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function ExtraArgsEditor({
|
export default function ExtraArgsEditor({
|
||||||
args,
|
args,
|
||||||
onChange,
|
onChange,
|
||||||
disabled = false,
|
disabled = false,
|
||||||
|
modelType,
|
||||||
}: ExtraArgsEditorProps) {
|
}: ExtraArgsEditorProps) {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@@ -46,7 +53,27 @@ export default function ExtraArgsEditor({
|
|||||||
return (
|
return (
|
||||||
<div className="space-y-2">
|
<div className="space-y-2">
|
||||||
<div className="flex items-center justify-between">
|
<div className="flex items-center justify-between">
|
||||||
<Label>{t('models.extraParameters')}</Label>
|
<div className="flex items-center gap-1">
|
||||||
|
<Label>{t('models.extraParameters')}</Label>
|
||||||
|
{modelType === 'rerank' && (
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<HelpCircle className="h-4 w-4 text-muted-foreground cursor-help" />
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent className="max-w-xs">
|
||||||
|
<div className="space-y-1 text-sm">
|
||||||
|
<p>
|
||||||
|
<strong>rerank_url</strong>: {t('models.rerankUrlTooltip')}
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
<strong>rerank_path</strong>:{' '}
|
||||||
|
{t('models.rerankPathTooltip')}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
{!disabled && (
|
{!disabled && (
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
|
|||||||
@@ -139,7 +139,11 @@ export default function ModelItem({
|
|||||||
<div className="flex items-center gap-2 flex-wrap">
|
<div className="flex items-center gap-2 flex-wrap">
|
||||||
<span className="text-sm font-medium">{model.name}</span>
|
<span className="text-sm font-medium">{model.name}</span>
|
||||||
<Badge variant="secondary" className="text-xs">
|
<Badge variant="secondary" className="text-xs">
|
||||||
{modelType === 'llm' ? t('models.chat') : t('models.embedding')}
|
{modelType === 'llm'
|
||||||
|
? t('models.chat')
|
||||||
|
: modelType === 'embedding'
|
||||||
|
? t('models.embedding')
|
||||||
|
: t('models.rerank')}
|
||||||
</Badge>
|
</Badge>
|
||||||
{modelType === 'llm' &&
|
{modelType === 'llm' &&
|
||||||
(model as LLMModel).abilities?.includes('vision') && (
|
(model as LLMModel).abilities?.includes('vision') && (
|
||||||
@@ -263,6 +267,7 @@ export default function ModelItem({
|
|||||||
args={editExtraArgs}
|
args={editExtraArgs}
|
||||||
onChange={setEditExtraArgs}
|
onChange={setEditExtraArgs}
|
||||||
disabled={isLangBotModels}
|
disabled={isLangBotModels}
|
||||||
|
modelType={modelType}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<div className="flex gap-2">
|
<div className="flex gap-2">
|
||||||
|
|||||||
@@ -134,9 +134,12 @@ export default function ProviderCard({
|
|||||||
const canDelete =
|
const canDelete =
|
||||||
!isLangBotModels &&
|
!isLangBotModels &&
|
||||||
(provider.llm_count || 0) === 0 &&
|
(provider.llm_count || 0) === 0 &&
|
||||||
(provider.embedding_count || 0) === 0;
|
(provider.embedding_count || 0) === 0 &&
|
||||||
|
(provider.rerank_count || 0) === 0;
|
||||||
const totalModels =
|
const totalModels =
|
||||||
(provider.llm_count || 0) + (provider.embedding_count || 0);
|
(provider.llm_count || 0) +
|
||||||
|
(provider.embedding_count || 0) +
|
||||||
|
(provider.rerank_count || 0);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Card className="mb-2">
|
<Card className="mb-2">
|
||||||
@@ -393,11 +396,44 @@ export default function ProviderCard({
|
|||||||
onResetTestResult={onResetTestResult}
|
onResetTestResult={onResetTestResult}
|
||||||
/>
|
/>
|
||||||
))}
|
))}
|
||||||
{models.llm.length === 0 && models.embedding.length === 0 && (
|
{models.rerank.map((model) => (
|
||||||
<p className="text-sm text-muted-foreground text-center py-4">
|
<ModelItem
|
||||||
{t('models.noModels')}
|
key={model.uuid}
|
||||||
</p>
|
model={model}
|
||||||
)}
|
modelType="rerank"
|
||||||
|
isLangBotModels={isLangBotModels}
|
||||||
|
editModelPopoverOpen={editModelPopoverOpen}
|
||||||
|
deleteConfirmOpen={deleteConfirmOpen}
|
||||||
|
onOpenEditModel={onOpenEditModel}
|
||||||
|
onCloseEditModel={onCloseEditModel}
|
||||||
|
onOpenDeleteConfirm={onOpenDeleteConfirm}
|
||||||
|
onCloseDeleteConfirm={onCloseDeleteConfirm}
|
||||||
|
onDeleteModel={() => onDeleteModel(model.uuid, 'rerank')}
|
||||||
|
onUpdateModel={(name, abilities, extraArgs) =>
|
||||||
|
onUpdateModel(
|
||||||
|
model.uuid,
|
||||||
|
'rerank',
|
||||||
|
name,
|
||||||
|
abilities,
|
||||||
|
extraArgs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
onTestModel={(name, abilities, extraArgs) =>
|
||||||
|
onTestModel(name, 'rerank', abilities, extraArgs)
|
||||||
|
}
|
||||||
|
isSubmitting={isSubmitting}
|
||||||
|
isTesting={isTesting}
|
||||||
|
testResult={testResult}
|
||||||
|
onResetTestResult={onResetTestResult}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
{models.llm.length === 0 &&
|
||||||
|
models.embedding.length === 0 &&
|
||||||
|
models.rerank.length === 0 && (
|
||||||
|
<p className="text-sm text-muted-foreground text-center py-4">
|
||||||
|
{t('models.noModels')}
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<p className="text-sm text-muted-foreground text-center py-4">
|
<p className="text-sm text-muted-foreground text-center py-4">
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import {
|
import {
|
||||||
LLMModel,
|
LLMModel,
|
||||||
EmbeddingModel,
|
EmbeddingModel,
|
||||||
|
RerankModel,
|
||||||
ModelProvider,
|
ModelProvider,
|
||||||
ProviderScanDebugInfo,
|
ProviderScanDebugInfo,
|
||||||
ScannedProviderModel,
|
ScannedProviderModel,
|
||||||
@@ -12,11 +13,12 @@ export type ExtraArg = {
|
|||||||
value: string;
|
value: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ModelType = 'llm' | 'embedding';
|
export type ModelType = 'llm' | 'embedding' | 'rerank';
|
||||||
|
|
||||||
export interface ProviderModels {
|
export interface ProviderModels {
|
||||||
llm: LLMModel[];
|
llm: LLMModel[];
|
||||||
embedding: EmbeddingModel[];
|
embedding: EmbeddingModel[];
|
||||||
|
rerank: RerankModel[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TestResult {
|
export interface TestResult {
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ export interface ModelProvider {
|
|||||||
api_keys: string[];
|
api_keys: string[];
|
||||||
llm_count?: number;
|
llm_count?: number;
|
||||||
embedding_count?: number;
|
embedding_count?: number;
|
||||||
|
rerank_count?: number;
|
||||||
created_at?: string;
|
created_at?: string;
|
||||||
updated_at?: string;
|
updated_at?: string;
|
||||||
}
|
}
|
||||||
@@ -114,6 +115,22 @@ export interface EmbeddingModel {
|
|||||||
extra_args?: object;
|
extra_args?: object;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ApiRespProviderRerankModels {
|
||||||
|
models: RerankModel[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ApiRespProviderRerankModel {
|
||||||
|
model: RerankModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RerankModel {
|
||||||
|
uuid: string;
|
||||||
|
name: string;
|
||||||
|
provider_uuid: string;
|
||||||
|
provider?: ModelProvider;
|
||||||
|
extra_args?: object;
|
||||||
|
}
|
||||||
|
|
||||||
export interface ApiRespPipelines {
|
export interface ApiRespPipelines {
|
||||||
pipelines: Pipeline[];
|
pipelines: Pipeline[];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ export enum DynamicFormItemType {
|
|||||||
SELECT = 'select',
|
SELECT = 'select',
|
||||||
LLM_MODEL_SELECTOR = 'llm-model-selector',
|
LLM_MODEL_SELECTOR = 'llm-model-selector',
|
||||||
EMBEDDING_MODEL_SELECTOR = 'embedding-model-selector',
|
EMBEDDING_MODEL_SELECTOR = 'embedding-model-selector',
|
||||||
|
RERANK_MODEL_SELECTOR = 'rerank-model-selector',
|
||||||
MODEL_FALLBACK_SELECTOR = 'model-fallback-selector',
|
MODEL_FALLBACK_SELECTOR = 'model-fallback-selector',
|
||||||
PROMPT_EDITOR = 'prompt-editor',
|
PROMPT_EDITOR = 'prompt-editor',
|
||||||
UNKNOWN = 'unknown',
|
UNKNOWN = 'unknown',
|
||||||
|
|||||||
@@ -31,6 +31,9 @@ import {
|
|||||||
ApiRespProviderEmbeddingModels,
|
ApiRespProviderEmbeddingModels,
|
||||||
ApiRespProviderEmbeddingModel,
|
ApiRespProviderEmbeddingModel,
|
||||||
EmbeddingModel,
|
EmbeddingModel,
|
||||||
|
ApiRespProviderRerankModels,
|
||||||
|
ApiRespProviderRerankModel,
|
||||||
|
RerankModel,
|
||||||
ApiRespPluginSystemStatus,
|
ApiRespPluginSystemStatus,
|
||||||
ApiRespMCPServers,
|
ApiRespMCPServers,
|
||||||
ApiRespMCPServer,
|
ApiRespMCPServer,
|
||||||
@@ -182,6 +185,39 @@ export class BackendClient extends BaseHttpClient {
|
|||||||
return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model);
|
return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============ Provider Model Rerank ============
|
||||||
|
public getProviderRerankModels(
|
||||||
|
providerUuid?: string,
|
||||||
|
): Promise<ApiRespProviderRerankModels> {
|
||||||
|
const params = providerUuid ? { provider_uuid: providerUuid } : {};
|
||||||
|
return this.get('/api/v1/provider/models/rerank', params);
|
||||||
|
}
|
||||||
|
|
||||||
|
public getProviderRerankModel(
|
||||||
|
uuid: string,
|
||||||
|
): Promise<ApiRespProviderRerankModel> {
|
||||||
|
return this.get(`/api/v1/provider/models/rerank/${uuid}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
public createProviderRerankModel(model: RerankModel): Promise<object> {
|
||||||
|
return this.post('/api/v1/provider/models/rerank', model);
|
||||||
|
}
|
||||||
|
|
||||||
|
public deleteProviderRerankModel(uuid: string): Promise<object> {
|
||||||
|
return this.delete(`/api/v1/provider/models/rerank/${uuid}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
public updateProviderRerankModel(
|
||||||
|
uuid: string,
|
||||||
|
model: RerankModel,
|
||||||
|
): Promise<object> {
|
||||||
|
return this.put(`/api/v1/provider/models/rerank/${uuid}`, model);
|
||||||
|
}
|
||||||
|
|
||||||
|
public testRerankModel(uuid: string, model: RerankModel): Promise<object> {
|
||||||
|
return this.post(`/api/v1/provider/models/rerank/${uuid}/test`, model);
|
||||||
|
}
|
||||||
|
|
||||||
// ============ Pipeline API ============
|
// ============ Pipeline API ============
|
||||||
public getGeneralPipelineMetadata(): Promise<GetPipelineMetadataResponseData> {
|
public getGeneralPipelineMetadata(): Promise<GetPipelineMetadataResponseData> {
|
||||||
// as designed, this method will be deprecated, and only for developer to check the prefered config schema
|
// as designed, this method will be deprecated, and only for developer to check the prefered config schema
|
||||||
|
|||||||
@@ -271,6 +271,10 @@ const enUS = {
|
|||||||
loadError: 'Failed to load data',
|
loadError: 'Failed to load data',
|
||||||
chat: 'Chat',
|
chat: 'Chat',
|
||||||
embedding: 'Embedding',
|
embedding: 'Embedding',
|
||||||
|
rerank: 'Rerank',
|
||||||
|
rerankUrlTooltip:
|
||||||
|
'Full URL override for rerank endpoint (e.g. https://dashscope.aliyuncs.com/compatible-api/v1/reranks)',
|
||||||
|
rerankPathTooltip: 'Path appended to base URL (default: rerank)',
|
||||||
modelsCount: '{{count}} model(s)',
|
modelsCount: '{{count}} model(s)',
|
||||||
expandModels: 'Expand',
|
expandModels: 'Expand',
|
||||||
collapseModels: 'Collapse',
|
collapseModels: 'Collapse',
|
||||||
|
|||||||
@@ -260,6 +260,10 @@ const zhHans = {
|
|||||||
loadError: '加载数据失败',
|
loadError: '加载数据失败',
|
||||||
chat: '对话',
|
chat: '对话',
|
||||||
embedding: '嵌入',
|
embedding: '嵌入',
|
||||||
|
rerank: '重排序',
|
||||||
|
rerankUrlTooltip:
|
||||||
|
'重排序接口的完整 URL 覆盖(如 https://dashscope.aliyuncs.com/compatible-api/v1/reranks)',
|
||||||
|
rerankPathTooltip: '添加到基础 URL 后的重排序路径(默认:rerank)',
|
||||||
modelsCount: '{{count}} 个模型',
|
modelsCount: '{{count}} 个模型',
|
||||||
expandModels: '展开',
|
expandModels: '展开',
|
||||||
collapseModels: '收起',
|
collapseModels: '收起',
|
||||||
|
|||||||