feat: model sync between api and manager layer

This commit is contained in:
Junyan Qin
2025-03-19 23:58:14 +08:00
parent 81481c9050
commit 4275459d45
2 changed files with 47 additions and 94 deletions

View File

@@ -35,7 +35,7 @@ class ModelsService:
**model_data
)
)
# TODO: add to runtime
await self.ap.model_mgr.load_model(model_data)
async def get_llm_model(self, model_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
@@ -54,8 +54,12 @@ class ModelsService:
sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data)
)
await self.ap.model_mgr.remove_model(model_uuid)
await self.ap.model_mgr.load_model(model_data)
async def delete_llm_model(self, model_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
# TODO: delete from runtime
await self.ap.model_mgr.remove_model(model_uuid)

View File

@@ -47,14 +47,17 @@ class ModelManager:
llm_models: list[RuntimeLLMModel]
requester_components: list[engine.Component]
requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache
def __init__(self, ap: app.Application):
self.ap = ap
self.requester_components = []
self.model_list = []
self.requesters = {}
self.token_mgrs = {}
self.llm_models = []
self.requester_components = []
self.requester_dict = {}
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
@@ -65,92 +68,21 @@ class ModelManager:
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
async def initialize(self):
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
# 初始化token_mgr, requester
for k, v in self.ap.provider_cfg.data['keys'].items():
self.token_mgrs[k] = token.TokenManager(k, v)
for component in self.requester_components:
api_cls = component.get_python_component_class()
api_inst = api_cls(self.ap)
await api_inst.initialize()
self.requesters[component.metadata.name] = api_inst
# 尝试从api获取最新的模型信息
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method="GET",
url=FETCH_MODEL_LIST_URL,
# 参数
params={
"version": self.ap.ver_mgr.get_current_version()
},
) as resp:
model_list = (await resp.json())['data']['list']
for model in model_list:
for index, local_model in enumerate(self.ap.llm_models_meta.data['list']):
if model['name'] == local_model['name']:
self.ap.llm_models_meta.data['list'][index] = model
break
else:
self.ap.llm_models_meta.data['list'].append(model)
await self.ap.llm_models_meta.dump_config()
except Exception as e:
self.ap.logger.debug(f'获取最新模型列表失败: {e}')
default_model_info: entities.LLMModelInfo = None
for model in self.ap.llm_models_meta.data['list']:
if model['name'] == 'default':
default_model_info = entities.LLMModelInfo(
name=model['name'],
model_name=None,
token_mgr=self.token_mgrs[model['token_mgr']],
requester=self.requesters[model['requester']],
tool_call_supported=model['tool_call_supported'],
vision_supported=model['vision_supported']
)
break
for model in self.ap.llm_models_meta.data['list']:
try:
model_name = model.get('model_name', default_model_info.model_name)
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
req = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
vision_supported = model.get('vision_supported', default_model_info.vision_supported)
model_info = entities.LLMModelInfo(
name=model['name'],
model_name=model_name,
token_mgr=token_mgr,
requester=req,
tool_call_supported=tool_call_supported,
vision_supported=vision_supported
)
self.model_list.append(model_info)
except Exception as e:
self.ap.logger.error(f"初始化模型 {model['name']} 失败: {type(e)} {e} ,请检查配置文件")
async def load_model_from_db(self):
"""从数据库加载模型"""
self.llm_models = []
# forge requester class dict
requester_dict: dict[str, type[requester.LLMAPIRequester]] = {}
for component in self.requester_components:
requester_dict[component.metadata.name] = component.get_python_component_class()
self.requester_dict = requester_dict
await self.load_model_from_db()
async def load_model_from_db(self):
"""从数据库加载模型"""
self.llm_models = []
# llm models
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel)
@@ -160,19 +92,36 @@ class ModelManager:
# load models
for llm_model in llm_models:
assert isinstance(llm_model, persistence_model.LLMModel)
runtime_llm_model = RuntimeLLMModel(
model_entity=llm_model,
token_mgr=token.TokenManager(
name=llm_model.uuid,
tokens=llm_model.api_keys
),
requester=requester_dict[llm_model.requester](
ap=self.ap,
config=llm_model.requester_config
)
await self.load_model(llm_model)
async def load_model(self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict):
"""加载模型"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.LLMModel(**model_info._mapping)
elif isinstance(model_info, dict):
model_info = persistence_model.LLMModel(**model_info)
runtime_llm_model = RuntimeLLMModel(
model_entity=model_info,
token_mgr=token.TokenManager(
name=model_info.uuid,
tokens=model_info.api_keys,
),
requester=self.requester_dict[model_info.requester](
ap=self.ap,
config=model_info.requester_config
)
self.llm_models.append(runtime_llm_model)
)
print(runtime_llm_model, runtime_llm_model.model_entity.name, "loaded")
self.llm_models.append(runtime_llm_model)
async def remove_model(self, model_uuid: str):
"""移除模型"""
for model in self.llm_models:
if model.model_entity.uuid == model_uuid:
self.llm_models.remove(model)
return
def get_available_requesters_info(self) -> list[dict]:
"""获取所有可用的请求器"""