diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index a68999fd..aff1a63a 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -35,7 +35,7 @@ class ModelsService: **model_data ) ) - # TODO: add to runtime + await self.ap.model_mgr.load_model(model_data) async def get_llm_model(self, model_uuid: str) -> dict | None: result = await self.ap.persistence_mgr.execute_async( @@ -54,8 +54,12 @@ class ModelsService: sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data) ) + await self.ap.model_mgr.remove_model(model_uuid) + await self.ap.model_mgr.load_model(model_data) + async def delete_llm_model(self, model_uuid: str) -> None: await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) ) - # TODO: delete from runtime \ No newline at end of file + + await self.ap.model_mgr.remove_model(model_uuid) diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index cafaf7e9..1eb3d776 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -47,14 +47,17 @@ class ModelManager: llm_models: list[RuntimeLLMModel] requester_components: list[engine.Component] + + requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache def __init__(self, ap: app.Application): self.ap = ap - self.requester_components = [] self.model_list = [] self.requesters = {} self.token_mgrs = {} self.llm_models = [] + self.requester_components = [] + self.requester_dict = {} async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: """通过名称获取模型 @@ -65,92 +68,21 @@ class ModelManager: raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置") async def initialize(self): - self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') - # 初始化token_mgr, requester - for k, v in self.ap.provider_cfg.data['keys'].items(): - self.token_mgrs[k] = token.TokenManager(k, v) - - for component in self.requester_components: - api_cls = component.get_python_component_class() - api_inst = api_cls(self.ap) - await api_inst.initialize() - self.requesters[component.metadata.name] = api_inst - - # 尝试从api获取最新的模型信息 - try: - async with aiohttp.ClientSession() as session: - async with session.request( - method="GET", - url=FETCH_MODEL_LIST_URL, - # 参数 - params={ - "version": self.ap.ver_mgr.get_current_version() - }, - ) as resp: - model_list = (await resp.json())['data']['list'] - - for model in model_list: - - for index, local_model in enumerate(self.ap.llm_models_meta.data['list']): - if model['name'] == local_model['name']: - self.ap.llm_models_meta.data['list'][index] = model - break - else: - self.ap.llm_models_meta.data['list'].append(model) - - await self.ap.llm_models_meta.dump_config() - - except Exception as e: - self.ap.logger.debug(f'获取最新模型列表失败: {e}') - - default_model_info: entities.LLMModelInfo = None - - for model in self.ap.llm_models_meta.data['list']: - if model['name'] == 'default': - default_model_info = entities.LLMModelInfo( - name=model['name'], - model_name=None, - token_mgr=self.token_mgrs[model['token_mgr']], - requester=self.requesters[model['requester']], - tool_call_supported=model['tool_call_supported'], - vision_supported=model['vision_supported'] - ) - break - - for model in self.ap.llm_models_meta.data['list']: - - try: - - model_name = model.get('model_name', default_model_info.model_name) - token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr - req = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester - tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported) - vision_supported = model.get('vision_supported', default_model_info.vision_supported) - - model_info = entities.LLMModelInfo( - name=model['name'], - model_name=model_name, - token_mgr=token_mgr, - requester=req, - tool_call_supported=tool_call_supported, - vision_supported=vision_supported - ) - self.model_list.append(model_info) - - except Exception as e: - self.ap.logger.error(f"初始化模型 {model['name']} 失败: {type(e)} {e} ,请检查配置文件") - - async def load_model_from_db(self): - """从数据库加载模型""" - self.llm_models = [] - # forge requester class dict requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} for component in self.requester_components: requester_dict[component.metadata.name] = component.get_python_component_class() + self.requester_dict = requester_dict + + await self.load_model_from_db() + + async def load_model_from_db(self): + """从数据库加载模型""" + self.llm_models = [] + # llm models result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_model.LLMModel) @@ -160,19 +92,36 @@ class ModelManager: # load models for llm_model in llm_models: - assert isinstance(llm_model, persistence_model.LLMModel) - runtime_llm_model = RuntimeLLMModel( - model_entity=llm_model, - token_mgr=token.TokenManager( - name=llm_model.uuid, - tokens=llm_model.api_keys - ), - requester=requester_dict[llm_model.requester]( - ap=self.ap, - config=llm_model.requester_config - ) + await self.load_model(llm_model) + + async def load_model(self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict): + """加载模型""" + + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.LLMModel(**model_info._mapping) + elif isinstance(model_info, dict): + model_info = persistence_model.LLMModel(**model_info) + + runtime_llm_model = RuntimeLLMModel( + model_entity=model_info, + token_mgr=token.TokenManager( + name=model_info.uuid, + tokens=model_info.api_keys, + ), + requester=self.requester_dict[model_info.requester]( + ap=self.ap, + config=model_info.requester_config ) - self.llm_models.append(runtime_llm_model) + ) + print(runtime_llm_model, runtime_llm_model.model_entity.name, "loaded") + self.llm_models.append(runtime_llm_model) + + async def remove_model(self, model_uuid: str): + """移除模型""" + for model in self.llm_models: + if model.model_entity.uuid == model_uuid: + self.llm_models.remove(model) + return def get_available_requesters_info(self) -> list[dict]: """获取所有可用的请求器"""