diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index f6d8cde1..e74e975d 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -6,6 +6,7 @@ import sqlalchemy from ....core import app from ....entity.persistence import model as persistence_model +from ....entity.persistence import pipeline as persistence_pipeline class ModelsService: @@ -40,6 +41,19 @@ class ModelsService: await self.ap.model_mgr.load_llm_model(llm_model) + # check if default pipeline has no model bound + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True) + ) + pipeline = result.first() + if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '': + pipeline_config = pipeline.config + pipeline_config['ai']['local-agent']['model'] = model_data['uuid'] + pipeline_data = { + "config": pipeline_config + } + await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data) + return model_data['uuid'] async def get_llm_model(self, model_uuid: str) -> dict | None: