From 7cd03b02435f75f18f9402ed0b9430688bc67da2 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 28 Mar 2025 15:55:03 +0800 Subject: [PATCH] feat: bind pipeline with runtime manager --- pkg/api/http/service/pipeline.py | 14 ++++- pkg/core/app.py | 4 +- pkg/core/stages/build_app.py | 6 +- pkg/pipeline/pipelinemgr.py | 93 +++++++++++++++++++++++++++++++ pkg/pipeline/stage.py | 4 +- pkg/pipeline/stagemgr.py | 2 +- pkg/provider/modelmgr/modelmgr.py | 6 +- 7 files changed, 119 insertions(+), 10 deletions(-) create mode 100644 pkg/pipeline/pipelinemgr.py diff --git a/pkg/api/http/service/pipeline.py b/pkg/api/http/service/pipeline.py index 22f9f4a9..7920c4c9 100644 --- a/pkg/api/http/service/pipeline.py +++ b/pkg/api/http/service/pipeline.py @@ -56,7 +56,10 @@ class PipelineService: await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data) ) - # TODO: 更新到pipeline manager + + pipeline = await self.get_pipeline(pipeline_data['uuid']) + + await self.ap.pipeline_mgr.load_pipeline(pipeline) return pipeline_data['uuid'] @@ -67,10 +70,15 @@ class PipelineService: await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data) ) - # TODO: 更新到pipeline manager + + await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid) + + pipeline = await self.get_pipeline(pipeline_uuid) + + await self.ap.pipeline_mgr.load_pipeline(pipeline) async def delete_pipeline(self, pipeline_uuid: str) -> None: await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid) ) - # TODO: 更新到pipeline manager + await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid) diff --git a/pkg/core/app.py b/pkg/core/app.py index f0884d81..0191cc02 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -20,7 +20,7 @@ from ..audit.center import v2 as center_mgr from ..command import cmdmgr from ..plugin import manager as plugin_mgr from ..pipeline import pool -from ..pipeline import controller, stagemgr +from ..pipeline import controller, stagemgr, pipelinemgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..persistence import mgr as persistencemgr from ..api.http.controller import main as http_controller @@ -102,6 +102,8 @@ class Application: stage_mgr: stagemgr.StageManager = None + pipeline_mgr: pipelinemgr.PipelineManager = None + ver_mgr: version_mgr.VersionManager = None ann_mgr: announce_mgr.AnnouncementManager = None diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 3873d719..0bd0d8a5 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -6,7 +6,7 @@ from .. import stage, app from ...utils import version, proxy, announce, platform from ...audit.center import v2 as center_v2 from ...audit import identifier -from ...pipeline import pool, controller, stagemgr +from ...pipeline import pool, controller, stagemgr, pipelinemgr from ...plugin import manager as plugin_mgr from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr @@ -119,6 +119,10 @@ class BuildAppStage(stage.BootingStage): await stage_mgr.initialize() ap.stage_mgr = stage_mgr + pipeline_mgr = pipelinemgr.PipelineManager(ap) + await pipeline_mgr.initialize() + ap.pipeline_mgr = pipeline_mgr + http_ctrl = http_controller.HTTPController(ap) await http_ctrl.initialize() ap.http_ctrl = http_ctrl diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py new file mode 100644 index 00000000..a805e5cd --- /dev/null +++ b/pkg/pipeline/pipelinemgr.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import typing + +import sqlalchemy + +from ..core import app, entities +from ..entity.persistence import pipeline as persistence_pipeline +from . import stagemgr, stage + + +class RuntimePipeline: + """运行时流水线""" + + ap: app.Application + + pipeline_entity: persistence_pipeline.LegacyPipeline + """流水线实体""" + + stage_containers: list[stagemgr.StageInstContainer] + """阶段实例容器""" + + def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]): + self.ap = ap + self.pipeline_entity = pipeline_entity + self.stage_containers = stage_containers + + async def run(self): + pass + + +class PipelineManager: + """流水线管理器""" + + # ====== 4.0 ====== + + ap: app.Application + + pipelines: list[RuntimePipeline] + + stage_dict: dict[str, type[stage.PipelineStage]] + + def __init__(self, ap: app.Application): + self.ap = ap + self.pipelines = [] + + async def initialize(self): + self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()} + + await self.load_pipelines_from_db() + + async def load_pipelines_from_db(self): + self.ap.logger.info('Loading pipelines from db...') + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_pipeline.LegacyPipeline) + ) + + pipelines = result.all() + + # load pipelines + for pipeline in pipelines: + await self.load_pipeline(pipeline) + + async def load_pipeline(self, pipeline_entity: persistence_pipeline.LegacyPipeline | sqlalchemy.Row[persistence_pipeline.LegacyPipeline] | dict): + + if isinstance(pipeline_entity, sqlalchemy.Row): + pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping) + elif isinstance(pipeline_entity, dict): + pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity) + + # initialize stage containers according to pipeline_entity.stages + stage_containers = [] + for stage_name in pipeline_entity.stages: + stage_containers.append(stagemgr.StageInstContainer( + stage_name=stage_name, + stage_class=self.stage_dict[stage_name] + )) + + runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers) + self.pipelines.append(runtime_pipeline) + + async def get_pipeline_by_uuid(self, uuid: str) -> RuntimePipeline | None: + for pipeline in self.pipelines: + if pipeline.pipeline_entity.uuid == uuid: + return pipeline + return None + + async def remove_pipeline(self, uuid: str): + for pipeline in self.pipelines: + if pipeline.pipeline_entity.uuid == uuid: + self.pipelines.remove(pipeline) + return \ No newline at end of file diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py index 56c092b5..206f2bdf 100644 --- a/pkg/pipeline/stage.py +++ b/pkg/pipeline/stage.py @@ -7,13 +7,13 @@ from ..core import app, entities as core_entities from . import entities -_stage_classes: dict[str, PipelineStage] = {} +preregistered_stages: dict[str, PipelineStage] = {} def stage_class(name: str): def decorator(cls): - _stage_classes[name] = cls + preregistered_stages[name] = cls return cls return decorator diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py index 2bd685d6..19fce2d6 100644 --- a/pkg/pipeline/stagemgr.py +++ b/pkg/pipeline/stagemgr.py @@ -58,7 +58,7 @@ class StageManager: """初始化 """ - for name, cls in stage._stage_classes.items(): + for name, cls in stage.preregistered_stages.items(): self.stage_containers.append(StageInstContainer( inst_name=name, inst=cls(self.ap) diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index e8329ac1..7db7a040 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -72,10 +72,12 @@ class ModelManager: self.requester_dict = requester_dict - await self.load_model_from_db() + await self.load_models_from_db() - async def load_model_from_db(self): + async def load_models_from_db(self): """从数据库加载模型""" + self.ap.logger.info('Loading models from db...') + self.llm_models = [] # llm models