From 323481d69b3b83c7d3cd5aa988505ad4c61067ea Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Mon, 20 Apr 2026 23:32:36 +0800 Subject: [PATCH] Feat/rerank model (#2137) * feat(provider): add rerank model management as a core model type * feat(provider): add rerank support to existing requesters and new rerank providers * feat(web): add rerank model management UI and pipeline config * fix(provider): correct rerank support_type after verification - Add rerank to OpenRouter (confirmed /api/v1/rerank endpoint) - Remove rerank from Ollama (no native support, PR #7219 unmerged) - Remove rerank from JiekouAI (no rerank docs found, URL path mismatch) * fix(provider): remove alru_cache from model getters and add rerank param hints * fix: resolve lint errors - Remove unused alru_cache import from modelmgr.py - Remove unused error_message variable in invoke_rerank - Fix prettier formatting in frontend files Co-Authored-By: Claude Opus 4.7 * fix: remove unused exception variable - Change `except Exception as e:` to `except Exception:` since e is not used - Fix prettier formatting in ProviderCard.tsx Co-Authored-By: Claude Opus 4.7 * fix: apply ruff format Co-Authored-By: Claude Opus 4.7 * feat(template): add rerank config fields to default pipeline config Co-Authored-By: Claude Opus 4.7 * chore: remove PR.md Co-Authored-By: Claude Opus 4.7 * fix(ui): remove duplicate rerank model form in AddModelPopover The form was being rendered twice: once in TabsContent manual mode and again in a separate conditional block for rerank tab. Co-Authored-By: Claude Opus 4.7 --------- Co-authored-by: Claude Opus 4.7 --- .../http/controller/groups/provider/models.py | 48 ++++++ .../controller/groups/provider/providers.py | 2 + src/langbot/pkg/api/http/service/model.py | 159 ++++++++++++++++++ src/langbot/pkg/api/http/service/provider.py | 17 +- src/langbot/pkg/core/app.py | 2 + src/langbot/pkg/core/stages/build_app.py | 3 + src/langbot/pkg/entity/persistence/model.py | 19 +++ .../versions/0003_add_rerank_models.py | 31 ++++ src/langbot/pkg/provider/modelmgr/modelmgr.py | 78 ++++++++- .../pkg/provider/modelmgr/requester.py | 72 ++++++++ .../modelmgr/requesters/302aichatcmpl.yaml | 1 + .../modelmgr/requesters/bailianchatcmpl.yaml | 1 + .../provider/modelmgr/requesters/chatcmpl.py | 85 ++++++++++ .../modelmgr/requesters/chatcmpl.yaml | 1 + .../provider/modelmgr/requesters/chroma.svg | 13 +- .../provider/modelmgr/requesters/cohere.svg | 1 + .../modelmgr/requesters/coherererank.yaml | 31 ++++ .../modelmgr/requesters/giteeaichatcmpl.yaml | 1 + .../pkg/provider/modelmgr/requesters/jina.svg | 1 + .../modelmgr/requesters/jinarerank.yaml | 31 ++++ .../requesters/openrouterchatcmpl.yaml | 1 + .../provider/modelmgr/requesters/seekdb.svg | 25 ++- .../requesters/siliconflowchatcmpl.yaml | 1 + .../provider/modelmgr/requesters/voyageai.svg | 1 + .../modelmgr/requesters/voyageairerank.yaml | 31 ++++ .../pkg/provider/runners/localagent.py | 39 +++++ .../templates/default-pipeline-config.json | 4 +- .../templates/metadata/pipeline/ai.yaml | 28 +++ .../dynamic-form/DynamicFormComponent.tsx | 3 + .../dynamic-form/DynamicFormItemComponent.tsx | 54 ++++++ .../components/models-dialog/ModelsDialog.tsx | 34 +++- .../components/AddModelPopover.tsx | 17 +- .../components/ExtraArgsEditor.tsx | 33 +++- .../models-dialog/components/ModelItem.tsx | 7 +- .../models-dialog/components/ProviderCard.tsx | 50 +++++- .../home/components/models-dialog/types.ts | 4 +- web/src/app/infra/entities/api/index.ts | 17 ++ web/src/app/infra/entities/form/dynamic.ts | 1 + web/src/app/infra/http/BackendClient.ts | 36 ++++ web/src/i18n/locales/en-US.ts | 4 + web/src/i18n/locales/zh-Hans.ts | 4 + 41 files changed, 950 insertions(+), 41 deletions(-) create mode 100644 src/langbot/pkg/persistence/alembic/versions/0003_add_rerank_models.py create mode 100644 src/langbot/pkg/provider/modelmgr/requesters/cohere.svg create mode 100644 src/langbot/pkg/provider/modelmgr/requesters/coherererank.yaml create mode 100644 src/langbot/pkg/provider/modelmgr/requesters/jina.svg create mode 100644 src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml create mode 100644 src/langbot/pkg/provider/modelmgr/requesters/voyageai.svg create mode 100644 src/langbot/pkg/provider/modelmgr/requesters/voyageairerank.yaml diff --git a/src/langbot/pkg/api/http/controller/groups/provider/models.py b/src/langbot/pkg/api/http/controller/groups/provider/models.py index cec582ee..f683c98f 100644 --- a/src/langbot/pkg/api/http/controller/groups/provider/models.py +++ b/src/langbot/pkg/api/http/controller/groups/provider/models.py @@ -97,3 +97,51 @@ class EmbeddingModelsRouterGroup(group.RouterGroup): await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data) return self.success() + + +@group.group_class('models/rerank', '/api/v1/provider/models/rerank') +class RerankModelsRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) + async def _() -> str: + if quart.request.method == 'GET': + provider_uuid = quart.request.args.get('provider_uuid') + if provider_uuid: + return self.success( + data={ + 'models': await self.ap.rerank_models_service.get_rerank_models_by_provider(provider_uuid) + } + ) + return self.success(data={'models': await self.ap.rerank_models_service.get_rerank_models()}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + model_uuid = await self.ap.rerank_models_service.create_rerank_model(json_data) + return self.success(data={'uuid': model_uuid}) + + @self.route('/', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) + async def _(model_uuid: str) -> str: + if quart.request.method == 'GET': + model = await self.ap.rerank_models_service.get_rerank_model(model_uuid) + + if model is None: + return self.http_status(404, -1, 'model not found') + + return self.success(data={'model': model}) + elif quart.request.method == 'PUT': + json_data = await quart.request.json + + await self.ap.rerank_models_service.update_rerank_model(model_uuid, json_data) + + return self.success() + elif quart.request.method == 'DELETE': + await self.ap.rerank_models_service.delete_rerank_model(model_uuid) + + return self.success() + + @self.route('//test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) + async def _(model_uuid: str) -> str: + json_data = await quart.request.json + + await self.ap.rerank_models_service.test_rerank_model(model_uuid, json_data) + + return self.success() diff --git a/src/langbot/pkg/api/http/controller/groups/provider/providers.py b/src/langbot/pkg/api/http/controller/groups/provider/providers.py index d303f178..fcea598f 100644 --- a/src/langbot/pkg/api/http/controller/groups/provider/providers.py +++ b/src/langbot/pkg/api/http/controller/groups/provider/providers.py @@ -15,6 +15,7 @@ class ModelProvidersRouterGroup(group.RouterGroup): counts = await self.ap.provider_service.get_provider_model_counts(provider['uuid']) provider['llm_count'] = counts['llm_count'] provider['embedding_count'] = counts['embedding_count'] + provider['rerank_count'] = counts['rerank_count'] return self.success(data={'providers': providers}) elif quart.request.method == 'POST': json_data = await quart.request.json @@ -32,6 +33,7 @@ class ModelProvidersRouterGroup(group.RouterGroup): counts = await self.ap.provider_service.get_provider_model_counts(provider_uuid) provider['llm_count'] = counts['llm_count'] provider['embedding_count'] = counts['embedding_count'] + provider['rerank_count'] = counts['rerank_count'] return self.success(data={'provider': provider}) elif quart.request.method == 'PUT': json_data = await quart.request.json diff --git a/src/langbot/pkg/api/http/service/model.py b/src/langbot/pkg/api/http/service/model.py index f10dcd02..b670e99d 100644 --- a/src/langbot/pkg/api/http/service/model.py +++ b/src/langbot/pkg/api/http/service/model.py @@ -367,3 +367,162 @@ class EmbeddingModelsService: input_text=['Hello, world!'], extra_args={}, ) + + +class RerankModelsService: + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_rerank_models(self) -> list[dict]: + """Get all rerank models with provider info""" + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel)) + models = result.all() + + providers_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider) + ) + providers = {p.uuid: p for p in providers_result.all()} + + models_list = [] + for model in models: + model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model) + provider = providers.get(model.provider_uuid) + if provider: + provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider) + model_dict['provider'] = _parse_provider_api_keys(provider_dict) + models_list.append(model_dict) + + return models_list + + async def get_rerank_models_by_provider(self, provider_uuid: str) -> list[dict]: + """Get rerank models by provider UUID""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.RerankModel).where( + persistence_model.RerankModel.provider_uuid == provider_uuid + ) + ) + models = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, m) for m in models] + + async def create_rerank_model(self, model_data: dict, preserve_uuid: bool = False) -> str: + """Create a new rerank model""" + if not preserve_uuid: + model_data['uuid'] = str(uuid.uuid4()) + + if 'provider' in model_data: + provider_data = model_data.pop('provider') + if provider_data.get('uuid'): + model_data['provider_uuid'] = provider_data['uuid'] + else: + provider_uuid = await self.ap.provider_service.find_or_create_provider( + requester=provider_data.get('requester', ''), + base_url=provider_data.get('base_url', ''), + api_keys=provider_data.get('api_keys', []), + ) + model_data['provider_uuid'] = provider_uuid + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_model.RerankModel).values(**model_data) + ) + + runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid']) + if runtime_provider is None: + raise Exception('provider not found') + + runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider( + persistence_model.RerankModel(**model_data), + runtime_provider, + ) + self.ap.model_mgr.rerank_models.append(runtime_rerank_model) + + return model_data['uuid'] + + async def get_rerank_model(self, model_uuid: str) -> dict | None: + """Get a single rerank model with provider info""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid) + ) + model = result.first() + if model is None: + return None + + model_dict = self.ap.persistence_mgr.serialize_model(persistence_model.RerankModel, model) + + provider_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider).where( + persistence_model.ModelProvider.uuid == model.provider_uuid + ) + ) + provider = provider_result.first() + if provider: + provider_dict = self.ap.persistence_mgr.serialize_model(persistence_model.ModelProvider, provider) + model_dict['provider'] = _parse_provider_api_keys(provider_dict) + + return model_dict + + async def update_rerank_model(self, model_uuid: str, model_data: dict) -> None: + """Update an existing rerank model""" + if 'uuid' in model_data: + del model_data['uuid'] + + if 'provider' in model_data: + provider_data = model_data.pop('provider') + if provider_data.get('uuid'): + model_data['provider_uuid'] = provider_data['uuid'] + else: + provider_uuid = await self.ap.provider_service.find_or_create_provider( + requester=provider_data.get('requester', ''), + base_url=provider_data.get('base_url', ''), + api_keys=provider_data.get('api_keys', []), + ) + model_data['provider_uuid'] = provider_uuid + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.RerankModel) + .where(persistence_model.RerankModel.uuid == model_uuid) + .values(**model_data) + ) + + await self.ap.model_mgr.remove_rerank_model(model_uuid) + + runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid']) + if runtime_provider is None: + raise Exception('provider not found') + + runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider( + persistence_model.RerankModel(**model_data), + runtime_provider, + ) + self.ap.model_mgr.rerank_models.append(runtime_rerank_model) + + async def delete_rerank_model(self, model_uuid: str) -> None: + """Delete a rerank model""" + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_model.RerankModel).where(persistence_model.RerankModel.uuid == model_uuid) + ) + await self.ap.model_mgr.remove_rerank_model(model_uuid) + + async def test_rerank_model(self, model_uuid: str, model_data: dict) -> None: + """Test a rerank model""" + runtime_rerank_model: model_requester.RuntimeRerankModel | None = None + + if model_uuid != '_': + for model in self.ap.model_mgr.rerank_models: + if model.model_entity.uuid == model_uuid: + runtime_rerank_model = model + break + if runtime_rerank_model is None: + raise Exception('model not found') + else: + runtime_rerank_model = await self.ap.model_mgr.init_temporary_runtime_rerank_model(model_data) + + await runtime_rerank_model.provider.invoke_rerank( + model=runtime_rerank_model, + query='What is artificial intelligence?', + documents=[ + 'Artificial intelligence is a branch of computer science.', + 'The weather is nice today.', + ], + ) diff --git a/src/langbot/pkg/api/http/service/provider.py b/src/langbot/pkg/api/http/service/provider.py index 24354731..503bf957 100644 --- a/src/langbot/pkg/api/http/service/provider.py +++ b/src/langbot/pkg/api/http/service/provider.py @@ -98,6 +98,14 @@ class ModelProviderService: if embedding_result.first() is not None: raise ValueError('Cannot delete provider: Embedding models still reference it') + rerank_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.RerankModel).where( + persistence_model.RerankModel.provider_uuid == provider_uuid + ) + ) + if rerank_result.first() is not None: + raise ValueError('Cannot delete provider: Rerank models still reference it') + await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_model.ModelProvider).where( persistence_model.ModelProvider.uuid == provider_uuid @@ -122,7 +130,14 @@ class ModelProviderService: ) embedding_count = embedding_result.scalar() or 0 - return {'llm_count': llm_count, 'embedding_count': embedding_count} + rerank_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(sqlalchemy.func.count()) + .select_from(persistence_model.RerankModel) + .where(persistence_model.RerankModel.provider_uuid == provider_uuid) + ) + rerank_count = rerank_result.scalar() or 0 + + return {'llm_count': llm_count, 'embedding_count': embedding_count, 'rerank_count': rerank_count} async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str: """Find existing provider or create new one""" diff --git a/src/langbot/pkg/core/app.py b/src/langbot/pkg/core/app.py index e515cfb9..aa1acd61 100644 --- a/src/langbot/pkg/core/app.py +++ b/src/langbot/pkg/core/app.py @@ -133,6 +133,8 @@ class Application: embedding_models_service: model_service.EmbeddingModelsService = None + rerank_models_service: model_service.RerankModelsService = None + provider_service: provider_service.ModelProviderService = None pipeline_service: pipeline_service.PipelineService = None diff --git a/src/langbot/pkg/core/stages/build_app.py b/src/langbot/pkg/core/stages/build_app.py index 62f0ae7b..71ff4262 100644 --- a/src/langbot/pkg/core/stages/build_app.py +++ b/src/langbot/pkg/core/stages/build_app.py @@ -61,6 +61,9 @@ class BuildAppStage(stage.BootingStage): embedding_models_service_inst = model_service.EmbeddingModelsService(ap) ap.embedding_models_service = embedding_models_service_inst + rerank_models_service_inst = model_service.RerankModelsService(ap) + ap.rerank_models_service = rerank_models_service_inst + provider_service_inst = provider_service.ModelProviderService(ap) ap.provider_service = provider_service_inst diff --git a/src/langbot/pkg/entity/persistence/model.py b/src/langbot/pkg/entity/persistence/model.py index 8ac3bd18..3c96acd7 100644 --- a/src/langbot/pkg/entity/persistence/model.py +++ b/src/langbot/pkg/entity/persistence/model.py @@ -59,3 +59,22 @@ class EmbeddingModel(Base): server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now(), ) + + +class RerankModel(Base): + """Rerank model""" + + __tablename__ = 'rerank_models' + + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + provider_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + prefered_ranking = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/src/langbot/pkg/persistence/alembic/versions/0003_add_rerank_models.py b/src/langbot/pkg/persistence/alembic/versions/0003_add_rerank_models.py new file mode 100644 index 00000000..78fb7ce5 --- /dev/null +++ b/src/langbot/pkg/persistence/alembic/versions/0003_add_rerank_models.py @@ -0,0 +1,31 @@ +"""add rerank_models table + +Revision ID: 0003_add_rerank_models +Revises: 0002_sample +Create Date: 2026-04-19 +""" + +import sqlalchemy as sa +from alembic import op + +revision = '0003_add_rerank_models' +down_revision = '0002_sample' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + 'rerank_models', + sa.Column('uuid', sa.String(255), primary_key=True, unique=True), + sa.Column('name', sa.String(255), nullable=False), + sa.Column('provider_uuid', sa.String(255), nullable=False), + sa.Column('extra_args', sa.JSON, nullable=False, server_default='{}'), + sa.Column('prefered_ranking', sa.Integer, nullable=False, server_default='0'), + sa.Column('created_at', sa.DateTime, nullable=False, server_default=sa.func.now()), + sa.Column('updated_at', sa.DateTime, nullable=False, server_default=sa.func.now()), + ) + + +def downgrade() -> None: + op.drop_table('rerank_models') diff --git a/src/langbot/pkg/provider/modelmgr/modelmgr.py b/src/langbot/pkg/provider/modelmgr/modelmgr.py index bce10b6a..976f1263 100644 --- a/src/langbot/pkg/provider/modelmgr/modelmgr.py +++ b/src/langbot/pkg/provider/modelmgr/modelmgr.py @@ -9,7 +9,6 @@ from ...discover import engine from . import token from ...entity.persistence import model as persistence_model from ...entity.errors import provider as provider_errors -from async_lru import alru_cache class ModelManager: @@ -24,6 +23,8 @@ class ModelManager: embedding_models: list[requester.RuntimeEmbeddingModel] + rerank_models: list[requester.RuntimeRerankModel] + requester_components: list[engine.Component] requester_dict: dict[str, type[requester.ProviderAPIRequester]] @@ -32,6 +33,7 @@ class ModelManager: self.ap = ap self.llm_models = [] self.embedding_models = [] + self.rerank_models = [] self.requester_components = [] self.requester_dict = {} @@ -64,8 +66,7 @@ class ModelManager: self.llm_models = [] self.embedding_models = [] - - # Load all providers first + self.rerank_models = [] self.provider_dict = {} providers_result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_model.ModelProvider) @@ -110,6 +111,22 @@ class ModelManager: except Exception as e: self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}') + # Load rerank models + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.RerankModel)) + rerank_models = result.all() + for rerank_model in rerank_models: + try: + provider = self.provider_dict.get(rerank_model.provider_uuid) + if provider is None: + self.ap.logger.warning( + f'Provider {rerank_model.provider_uuid} not found for model {rerank_model.uuid}' + ) + continue + runtime_rerank_model = await self.load_rerank_model_with_provider(rerank_model, provider) + self.rerank_models.append(runtime_rerank_model) + except Exception as e: + self.ap.logger.error(f'Failed to load model {rerank_model.uuid}: {e}\n{traceback.format_exc()}') + async def sync_new_models_from_space(self): """Sync models from Space""" space_model_provider = await self.ap.persistence_mgr.execute_async( @@ -212,6 +229,26 @@ class ModelManager: return runtime_embedding_model + async def init_temporary_runtime_rerank_model( + self, + model_info: dict, + ) -> requester.RuntimeRerankModel: + """Initialize runtime rerank model from dict (for testing)""" + provider_info = model_info.get('provider', {}) + runtime_provider = await self.load_provider(provider_info) + + runtime_rerank_model = requester.RuntimeRerankModel( + model_entity=persistence_model.RerankModel( + uuid=model_info.get('uuid', ''), + name=model_info.get('name', ''), + provider_uuid='', + extra_args=model_info.get('extra_args', {}), + ), + provider=runtime_provider, + ) + + return runtime_rerank_model + async def load_provider( self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict ) -> requester.RuntimeProvider: @@ -269,6 +306,9 @@ class ModelManager: for model in self.embedding_models: if model.provider.provider_entity.uuid == provider_uuid: model.provider = new_runtime_provider + for model in self.rerank_models: + if model.provider.provider_entity.uuid == provider_uuid: + model.provider = new_runtime_provider # update ref in provider dict self.provider_dict[provider_uuid] = new_runtime_provider @@ -305,6 +345,22 @@ class ModelManager: return runtime_embedding_model + async def load_rerank_model_with_provider( + self, + model_info: persistence_model.RerankModel | sqlalchemy.Row, + provider: requester.RuntimeProvider, + ) -> requester.RuntimeRerankModel: + """Load rerank model with provider info""" + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.RerankModel(**model_info._mapping) + + runtime_rerank_model = requester.RuntimeRerankModel( + model_entity=model_info, + provider=provider, + ) + + return runtime_rerank_model + async def load_llm_model(self, model_info: dict): """Load LLM model from dict (with provider info)""" provider_info = model_info.get('provider', {}) @@ -352,7 +408,6 @@ class ModelManager: await self.load_embedding_model_with_provider(model_entity, provider_entity) - @alru_cache(ttl=60 * 5) async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel: """Get LLM model by uuid""" for model in self.llm_models: @@ -360,7 +415,6 @@ class ModelManager: return model raise ValueError(f'LLM model {uuid} not found') - @alru_cache(ttl=60 * 5) async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel: """Get embedding model by uuid""" for model in self.embedding_models: @@ -368,6 +422,13 @@ class ModelManager: return model raise ValueError(f'Embedding model {uuid} not found') + async def get_rerank_model_by_uuid(self, uuid: str) -> requester.RuntimeRerankModel: + """Get rerank model by uuid""" + for model in self.rerank_models: + if model.model_entity.uuid == uuid: + return model + raise ValueError(f'Rerank model {uuid} not found') + async def remove_llm_model(self, model_uuid: str): """Remove LLM model""" for model in self.llm_models: @@ -382,6 +443,13 @@ class ModelManager: self.embedding_models.remove(model) return + async def remove_rerank_model(self, model_uuid: str): + """Remove rerank model""" + for model in self.rerank_models: + if model.model_entity.uuid == model_uuid: + self.rerank_models.remove(model) + return + def get_available_requesters_info(self, model_type: str) -> list[dict]: """Get all available requesters""" if model_type != '': diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index 301bdfe9..08fee3ab 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -247,6 +247,40 @@ class RuntimeProvider: except Exception as monitor_err: self.requester.ap.logger.error(f'[Monitoring] Failed to record embedding call: {monitor_err}') + async def invoke_rerank( + self, + model: RuntimeRerankModel, + query: str, + documents: typing.List[str], + extra_args: dict[str, typing.Any] = {}, + ) -> typing.List[dict]: + """Bridge method for invoking rerank with monitoring""" + start_time = time.time() + status = 'success' + + try: + result = await self.requester.invoke_rerank( + model=model, + query=query, + documents=documents, + extra_args=extra_args, + ) + return result + + except Exception: + status = 'error' + raise + finally: + duration_ms = int((time.time() - start_time) * 1000) + + try: + self.requester.ap.logger.debug( + f'[Rerank] model={model.model_entity.name} docs={len(documents)} ' + f'duration={duration_ms}ms status={status}' + ) + except Exception as monitor_err: + self.requester.ap.logger.error(f'[Monitoring] Failed to record rerank call: {monitor_err}') + class RuntimeLLMModel: """运行时模型""" @@ -284,6 +318,24 @@ class RuntimeEmbeddingModel: self.provider = provider +class RuntimeRerankModel: + """运行时 Rerank 模型""" + + model_entity: persistence_model.RerankModel + """模型数据""" + + provider: RuntimeProvider + """提供商实例""" + + def __init__( + self, + model_entity: persistence_model.RerankModel, + provider: RuntimeProvider, + ): + self.model_entity = model_entity + self.provider = provider + + class ProviderAPIRequester(metaclass=abc.ABCMeta): """Provider API请求器""" @@ -376,3 +428,23 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): 或者 tuple[typing.List[typing.List[float]], dict]: 返回 (embedding 向量, usage_info) """ pass + + async def invoke_rerank( + self, + model: RuntimeRerankModel, + query: str, + documents: typing.List[str], + extra_args: dict[str, typing.Any] = {}, + ) -> typing.List[dict]: + """调用 Rerank API + + Args: + model (RuntimeRerankModel): 使用的模型信息 + query (str): 查询文本 + documents (typing.List[str]): 待重排序的文档列表 + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + + Returns: + typing.List[dict]: [{"index": int, "relevance_score": float}, ...] + """ + raise NotImplementedError('This requester does not support rerank') diff --git a/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml index 4fc22be4..e4f70cae 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml @@ -25,6 +25,7 @@ spec: support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index 7c405232..fc5998c4 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -24,6 +24,7 @@ spec: default: 120 support_type: - llm + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py index 24f7a200..da24bda0 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -615,3 +615,88 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def invoke_rerank( + self, + model: requester.RuntimeRerankModel, + query: str, + documents: typing.List[str], + extra_args: dict[str, typing.Any] = {}, + ) -> typing.List[dict]: + """Standard /rerank endpoint (Jina/Cohere/SiliconFlow/Voyage/DashScope compatible) + + Supports extra_args from model.extra_args: + - rerank_url: full URL override (e.g. "https://dashscope.aliyuncs.com/compatible-api/v1/reranks") + - rerank_path: path override appended to base_url (e.g. "reranks" instead of default "rerank") + - Any other fields are merged into the request payload. + """ + api_key = model.provider.token_mgr.get_token() + base_url = self.requester_cfg.get('base_url', '').rstrip('/') + timeout = self.requester_cfg.get('timeout', 120) + + merged_args = {} + if model.model_entity.extra_args: + merged_args.update(model.model_entity.extra_args) + if extra_args: + merged_args.update(extra_args) + + rerank_url = merged_args.pop('rerank_url', None) + rerank_path = merged_args.pop('rerank_path', 'rerank') + if not rerank_url: + rerank_url = f'{base_url}/{rerank_path}' + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}', + } + + payload = { + 'model': model.model_entity.name, + 'query': query, + 'documents': documents[:64], + 'top_n': min(len(documents), 64), + } + + if merged_args: + payload.update(merged_args) + + try: + async with httpx.AsyncClient(trust_env=True, timeout=timeout) as client: + resp = await client.post(rerank_url, headers=headers, json=payload) + resp.raise_for_status() + data = resp.json() + + results = self._parse_rerank_response(data) + + if results: + scores = [r.get('relevance_score', 0.0) for r in results] + min_score = min(scores) + max_score = max(scores) + if max_score - min_score > 1e-6: + for r in results: + r['relevance_score'] = (r['relevance_score'] - min_score) / (max_score - min_score) + + return results + except httpx.HTTPStatusError as e: + raise errors.RequesterError(f'Rerank request failed: {e.response.status_code} - {e.response.text}') + except httpx.TimeoutException: + raise errors.RequesterError('Rerank request timed out') + except Exception as e: + raise errors.RequesterError(f'Rerank request error: {str(e)}') + + @staticmethod + def _parse_rerank_response(data: dict) -> typing.List[dict]: + """Parse rerank response from various providers. + + Handles: + - Jina/Cohere/SiliconFlow: {"results": [{"index", "relevance_score"}]} + - Voyage AI: {"data": [{"index", "relevance_score"}]} + - DashScope: {"output": {"results": [{"index", "relevance_score"}]}} + """ + if 'results' in data: + return data['results'] + if 'data' in data: + return data['data'] + if 'output' in data and isinstance(data['output'], dict): + return data['output'].get('results', []) + return [] diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.yaml index 4f588fb2..21bd6a05 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -25,6 +25,7 @@ spec: support_type: - llm - text-embedding + - rerank provider_category: manufacturer execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chroma.svg b/src/langbot/pkg/provider/modelmgr/requesters/chroma.svg index 15252321..b94c0d66 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/chroma.svg +++ b/src/langbot/pkg/provider/modelmgr/requesters/chroma.svg @@ -1,7 +1,8 @@ - - - - - - + + + Chroma Streamline Icon: https://streamlinehq.com + + + + \ No newline at end of file diff --git a/src/langbot/pkg/provider/modelmgr/requesters/cohere.svg b/src/langbot/pkg/provider/modelmgr/requesters/cohere.svg new file mode 100644 index 00000000..94bcb82c --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/cohere.svg @@ -0,0 +1 @@ +Cohere \ No newline at end of file diff --git a/src/langbot/pkg/provider/modelmgr/requesters/coherererank.yaml b/src/langbot/pkg/provider/modelmgr/requesters/coherererank.yaml new file mode 100644 index 00000000..f1ca209b --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/coherererank.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: cohere-rerank + label: + en_US: Cohere + zh_Hans: Cohere + icon: cohere.svg +spec: + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.cohere.com/v2 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + support_type: + - rerank + provider_category: manufacturer +execution: + python: + path: ./chatcmpl.py + attr: OpenAIChatCompletions diff --git a/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index e818bd7a..b7b158a7 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -25,6 +25,7 @@ spec: support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/jina.svg b/src/langbot/pkg/provider/modelmgr/requesters/jina.svg new file mode 100644 index 00000000..fc996bf9 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/jina.svg @@ -0,0 +1 @@ +Jina \ No newline at end of file diff --git a/src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml b/src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml new file mode 100644 index 00000000..3b448e38 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/jinarerank.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: jina-rerank + label: + en_US: Jina + zh_Hans: Jina + icon: jina.svg +spec: + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.jina.ai/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + support_type: + - rerank + provider_category: manufacturer +execution: + python: + path: ./chatcmpl.py + attr: OpenAIChatCompletions diff --git a/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml index f1603200..71064dc0 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml @@ -25,6 +25,7 @@ spec: support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/seekdb.svg b/src/langbot/pkg/provider/modelmgr/requesters/seekdb.svg index d1daf9d1..01e77353 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/seekdb.svg +++ b/src/langbot/pkg/provider/modelmgr/requesters/seekdb.svg @@ -1,8 +1,17 @@ - - - - - - - - + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 28d3314a..11a2ffa3 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/src/langbot/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -25,6 +25,7 @@ spec: support_type: - llm - text-embedding + - rerank provider_category: maas execution: python: diff --git a/src/langbot/pkg/provider/modelmgr/requesters/voyageai.svg b/src/langbot/pkg/provider/modelmgr/requesters/voyageai.svg new file mode 100644 index 00000000..71b21fc0 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/voyageai.svg @@ -0,0 +1 @@ +Voyage \ No newline at end of file diff --git a/src/langbot/pkg/provider/modelmgr/requesters/voyageairerank.yaml b/src/langbot/pkg/provider/modelmgr/requesters/voyageairerank.yaml new file mode 100644 index 00000000..a47b8d47 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/requesters/voyageairerank.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: LLMAPIRequester +metadata: + name: voyageai-rerank + label: + en_US: Voyage AI + zh_Hans: Voyage AI + icon: voyageai.svg +spec: + config: + - name: base_url + label: + en_US: Base URL + zh_Hans: 基础 URL + type: string + required: true + default: https://api.voyageai.com/v1 + - name: timeout + label: + en_US: Timeout + zh_Hans: 超时时间 + type: integer + required: true + default: 120 + support_type: + - rerank + provider_category: manufacturer +execution: + python: + path: ./chatcmpl.py + attr: OpenAIChatCompletions diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index 65f86678..b48e9cc3 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -172,6 +172,45 @@ class LocalAgentRunner(runner.RequestRunner): if result: all_results.extend(result) + # Rerank step: re-score results using a rerank model if configured + local_agent_config = query.pipeline_config.get('ai', {}).get('local-agent', {}) + rerank_model_uuid = local_agent_config.get('rerank-model', '') + if rerank_model_uuid == '__none__': + rerank_model_uuid = '' + self.ap.logger.info( + f'Rerank config: model_uuid={rerank_model_uuid!r}, ' + f'results={len(all_results)}, ' + f'local_agent_keys={list(local_agent_config.keys())}' + ) + if all_results and rerank_model_uuid: + try: + rerank_model = await self.ap.model_mgr.get_rerank_model_by_uuid(rerank_model_uuid) + rerank_top_k = int(local_agent_config.get('rerank-top-k', 5)) + + doc_texts = [] + for entry in all_results: + text = ' '.join(c.text for c in entry.content if c.type == 'text' and c.text) + doc_texts.append(text) + + doc_texts_capped = doc_texts[:64] + scores = await rerank_model.provider.invoke_rerank( + model=rerank_model, + query=user_message_text, + documents=doc_texts_capped, + ) + + scored = sorted(scores, key=lambda x: x.get('relevance_score', 0), reverse=True) + top_indices = [s['index'] for s in scored[:rerank_top_k] if s['index'] < len(all_results)] + all_results = [all_results[i] for i in top_indices] + + self.ap.logger.info( + f'Rerank complete: {len(doc_texts)} docs reranked -> top {len(all_results)} kept (top_k={rerank_top_k})' + ) + except ValueError: + self.ap.logger.warning(f'Rerank model {rerank_model_uuid} not found, skipping rerank') + except Exception as e: + self.ap.logger.warning(f'Rerank failed, using original order: {e}') + final_user_message_text = '' if all_results: diff --git a/src/langbot/templates/default-pipeline-config.json b/src/langbot/templates/default-pipeline-config.json index be0b6ef8..e40d3914 100644 --- a/src/langbot/templates/default-pipeline-config.json +++ b/src/langbot/templates/default-pipeline-config.json @@ -52,7 +52,9 @@ "content": "You are a helpful assistant." } ], - "knowledge-bases": [] + "knowledge-bases": [], + "rerank-model": "", + "rerank-top-k": 5 }, "dify-service-api": { "base-url": "https://api.dify.ai/v1", diff --git a/src/langbot/templates/metadata/pipeline/ai.yaml b/src/langbot/templates/metadata/pipeline/ai.yaml index 0bb201a7..a944c217 100644 --- a/src/langbot/templates/metadata/pipeline/ai.yaml +++ b/src/langbot/templates/metadata/pipeline/ai.yaml @@ -104,6 +104,34 @@ stages: field: __system.is_wizard operator: neq value: true + - name: rerank-model + label: + en_US: Rerank Model + zh_Hans: 重排序模型 + description: + en_US: Optional rerank model to improve retrieval quality by re-scoring retrieved chunks + zh_Hans: 可选的重排序模型,通过重新评分检索结果来提升检索质量 + type: rerank-model-selector + required: false + default: '' + show_if: + field: knowledge-bases + operator: neq + value: [] + - name: rerank-top-k + label: + en_US: Rerank Top K + zh_Hans: 重排序保留数量 + description: + en_US: Number of top results to keep after reranking + zh_Hans: 重排序后保留的最相关结果数量 + type: integer + required: false + default: 5 + show_if: + field: rerank-model + operator: neq + value: '' - name: dify-service-api label: en_US: Dify Service API diff --git a/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx index b6b21aab..e714d85b 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormComponent.tsx @@ -240,6 +240,9 @@ export default function DynamicFormComponent({ case 'embedding-model-selector': fieldSchema = z.string(); break; + case 'rerank-model-selector': + fieldSchema = z.string(); + break; case 'knowledge-base-selector': fieldSchema = z.string(); break; diff --git a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx index 7b574033..51831bde 100644 --- a/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx +++ b/web/src/app/home/components/dynamic-form/DynamicFormItemComponent.tsx @@ -23,6 +23,7 @@ import { Bot, KnowledgeBase, EmbeddingModel, + RerankModel, PluginTool, } from '@/app/infra/entities/api'; import { toast } from 'sonner'; @@ -74,6 +75,7 @@ export default function DynamicFormItemComponent({ }) { const [llmModels, setLlmModels] = useState([]); const [embeddingModels, setEmbeddingModels] = useState([]); + const [rerankModels, setRerankModels] = useState([]); const [knowledgeBases, setKnowledgeBases] = useState([]); const [bots, setBots] = useState([]); const [tools, setTools] = useState([]); @@ -180,6 +182,19 @@ export default function DynamicFormItemComponent({ } }, [config.type]); + useEffect(() => { + if (config.type === DynamicFormItemType.RERANK_MODEL_SELECTOR) { + httpClient + .getProviderRerankModels() + .then((resp) => { + setRerankModels(resp.models); + }) + .catch((err) => { + toast.error('Failed to load rerank models: ' + err.msg); + }); + } + }, [config.type]); + useEffect(() => { if (config.type === DynamicFormItemType.MODEL_FALLBACK_SELECTOR) { fetchLlmModels(); @@ -585,6 +600,45 @@ export default function DynamicFormItemComponent({ ); + case DynamicFormItemType.RERANK_MODEL_SELECTOR: + const groupedRerankModels = rerankModels.reduce( + (acc, model) => { + const providerName = model.provider?.name || 'Unknown'; + if (!acc[providerName]) acc[providerName] = []; + acc[providerName].push(model); + return acc; + }, + {} as Record, + ); + + return ( +
+ +
+ ); + case DynamicFormItemType.MODEL_FALLBACK_SELECTOR: { // Separate space models from regular models const fbSpaceModels = llmModels.filter( diff --git a/web/src/app/home/components/models-dialog/ModelsDialog.tsx b/web/src/app/home/components/models-dialog/ModelsDialog.tsx index 72694ad3..be9d3367 100644 --- a/web/src/app/home/components/models-dialog/ModelsDialog.tsx +++ b/web/src/app/home/components/models-dialog/ModelsDialog.tsx @@ -147,15 +147,17 @@ export default function ModelsDialog({ setLoadingProviders((prev) => new Set(prev).add(providerUuid)); } try { - const [llmResp, embeddingResp] = await Promise.all([ + const [llmResp, embeddingResp, rerankResp] = await Promise.all([ httpClient.getProviderLLMModels(providerUuid), httpClient.getProviderEmbeddingModels(providerUuid), + httpClient.getProviderRerankModels(providerUuid), ]); setProviderModels((prev) => ({ ...prev, [providerUuid]: { llm: llmResp.models, embedding: embeddingResp.models, + rerank: rerankResp.models, }, })); } catch (err) { @@ -247,12 +249,18 @@ export default function ModelsDialog({ abilities, extra_args: extraArgsObj, } as never); - } else { + } else if (modelType === 'embedding') { await httpClient.createProviderEmbeddingModel({ name, provider_uuid: providerUuid, extra_args: extraArgsObj, } as never); + } else { + await httpClient.createProviderRerankModel({ + name, + provider_uuid: providerUuid, + extra_args: extraArgsObj, + } as never); } setAddModelPopoverOpen(null); loadProviderModels(providerUuid, true); @@ -341,12 +349,18 @@ export default function ModelsDialog({ abilities, extra_args: extraArgsObj, } as never); - } else { + } else if (modelType === 'embedding') { await httpClient.updateProviderEmbeddingModel(modelId, { name, provider_uuid: providerUuid, extra_args: extraArgsObj, } as never); + } else { + await httpClient.updateProviderRerankModel(modelId, { + name, + provider_uuid: providerUuid, + extra_args: extraArgsObj, + } as never); } setEditModelPopoverOpen(null); loadProviderModels(providerUuid, true); @@ -366,8 +380,10 @@ export default function ModelsDialog({ try { if (modelType === 'llm') { await httpClient.deleteProviderLLMModel(modelId); - } else { + } else if (modelType === 'embedding') { await httpClient.deleteProviderEmbeddingModel(modelId); + } else { + await httpClient.deleteProviderRerankModel(modelId); } toast.success(t('models.deleteSuccess')); loadProviderModels(providerUuid, true); @@ -407,7 +423,7 @@ export default function ModelsDialog({ abilities, extra_args: extraArgsObj, } as never); - } else { + } else if (modelType === 'embedding') { await httpClient.testEmbeddingModel('_', { uuid: '', name, @@ -415,6 +431,14 @@ export default function ModelsDialog({ provider: providerData, extra_args: extraArgsObj, } as never); + } else { + await httpClient.testRerankModel('_', { + uuid: '', + name, + provider_uuid: '', + provider: providerData, + extra_args: extraArgsObj, + } as never); } const duration = Date.now() - startTime; setTestResult({ success: true, duration }); diff --git a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx index bdbf90a8..b4e3913c 100644 --- a/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx +++ b/web/src/app/home/components/models-dialog/components/AddModelPopover.tsx @@ -3,6 +3,7 @@ import { Plus, MessageSquareText, Cpu, + ArrowUpDown, Eye, Wrench, Check, @@ -265,7 +266,7 @@ export default function AddModelPopover({ onClick={(e) => e.stopPropagation()} > setTab(v as ModelType)}> - + {t('models.chat')} @@ -274,6 +275,10 @@ export default function AddModelPopover({ {t('models.embedding')} + + + {t('models.rerank')} + )} - +
diff --git a/web/src/app/home/components/models-dialog/components/ExtraArgsEditor.tsx b/web/src/app/home/components/models-dialog/components/ExtraArgsEditor.tsx index e00d8269..72a10c1f 100644 --- a/web/src/app/home/components/models-dialog/components/ExtraArgsEditor.tsx +++ b/web/src/app/home/components/models-dialog/components/ExtraArgsEditor.tsx @@ -1,4 +1,4 @@ -import { Plus, X } from 'lucide-react'; +import { Plus, X, HelpCircle } from 'lucide-react'; import { Button } from '@/components/ui/button'; import { Input } from '@/components/ui/input'; import { Label } from '@/components/ui/label'; @@ -9,19 +9,26 @@ import { SelectTrigger, SelectValue, } from '@/components/ui/select'; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/components/ui/tooltip'; import { useTranslation } from 'react-i18next'; -import { ExtraArg } from '../types'; +import { ExtraArg, ModelType } from '../types'; interface ExtraArgsEditorProps { args: ExtraArg[]; onChange: (args: ExtraArg[]) => void; disabled?: boolean; + modelType?: ModelType; } export default function ExtraArgsEditor({ args, onChange, disabled = false, + modelType, }: ExtraArgsEditorProps) { const { t } = useTranslation(); @@ -46,7 +53,27 @@ export default function ExtraArgsEditor({ return (
- +
+ + {modelType === 'rerank' && ( + + + + + +
+

+ rerank_url: {t('models.rerankUrlTooltip')} +

+

+ rerank_path:{' '} + {t('models.rerankPathTooltip')} +

+
+
+
+ )} +
{!disabled && (