mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-04 04:54:36 +00:00
feat: model sync between api and manager layer
This commit is contained in:
@@ -35,7 +35,7 @@ class ModelsService:
|
||||
**model_data
|
||||
)
|
||||
)
|
||||
# TODO: add to runtime
|
||||
await self.ap.model_mgr.load_model(model_data)
|
||||
|
||||
async def get_llm_model(self, model_uuid: str) -> dict | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
@@ -54,8 +54,12 @@ class ModelsService:
|
||||
sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data)
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_model(model_uuid)
|
||||
await self.ap.model_mgr.load_model(model_data)
|
||||
|
||||
async def delete_llm_model(self, model_uuid: str) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
||||
)
|
||||
# TODO: delete from runtime
|
||||
|
||||
await self.ap.model_mgr.remove_model(model_uuid)
|
||||
|
||||
@@ -47,14 +47,17 @@ class ModelManager:
|
||||
llm_models: list[RuntimeLLMModel]
|
||||
|
||||
requester_components: list[engine.Component]
|
||||
|
||||
requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.requester_components = []
|
||||
self.model_list = []
|
||||
self.requesters = {}
|
||||
self.token_mgrs = {}
|
||||
self.llm_models = []
|
||||
self.requester_components = []
|
||||
self.requester_dict = {}
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||
"""通过名称获取模型
|
||||
@@ -65,92 +68,21 @@ class ModelManager:
|
||||
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
|
||||
|
||||
# 初始化token_mgr, requester
|
||||
for k, v in self.ap.provider_cfg.data['keys'].items():
|
||||
self.token_mgrs[k] = token.TokenManager(k, v)
|
||||
|
||||
for component in self.requester_components:
|
||||
api_cls = component.get_python_component_class()
|
||||
api_inst = api_cls(self.ap)
|
||||
await api_inst.initialize()
|
||||
self.requesters[component.metadata.name] = api_inst
|
||||
|
||||
# 尝试从api获取最新的模型信息
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
method="GET",
|
||||
url=FETCH_MODEL_LIST_URL,
|
||||
# 参数
|
||||
params={
|
||||
"version": self.ap.ver_mgr.get_current_version()
|
||||
},
|
||||
) as resp:
|
||||
model_list = (await resp.json())['data']['list']
|
||||
|
||||
for model in model_list:
|
||||
|
||||
for index, local_model in enumerate(self.ap.llm_models_meta.data['list']):
|
||||
if model['name'] == local_model['name']:
|
||||
self.ap.llm_models_meta.data['list'][index] = model
|
||||
break
|
||||
else:
|
||||
self.ap.llm_models_meta.data['list'].append(model)
|
||||
|
||||
await self.ap.llm_models_meta.dump_config()
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.debug(f'获取最新模型列表失败: {e}')
|
||||
|
||||
default_model_info: entities.LLMModelInfo = None
|
||||
|
||||
for model in self.ap.llm_models_meta.data['list']:
|
||||
if model['name'] == 'default':
|
||||
default_model_info = entities.LLMModelInfo(
|
||||
name=model['name'],
|
||||
model_name=None,
|
||||
token_mgr=self.token_mgrs[model['token_mgr']],
|
||||
requester=self.requesters[model['requester']],
|
||||
tool_call_supported=model['tool_call_supported'],
|
||||
vision_supported=model['vision_supported']
|
||||
)
|
||||
break
|
||||
|
||||
for model in self.ap.llm_models_meta.data['list']:
|
||||
|
||||
try:
|
||||
|
||||
model_name = model.get('model_name', default_model_info.model_name)
|
||||
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
|
||||
req = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
|
||||
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
|
||||
vision_supported = model.get('vision_supported', default_model_info.vision_supported)
|
||||
|
||||
model_info = entities.LLMModelInfo(
|
||||
name=model['name'],
|
||||
model_name=model_name,
|
||||
token_mgr=token_mgr,
|
||||
requester=req,
|
||||
tool_call_supported=tool_call_supported,
|
||||
vision_supported=vision_supported
|
||||
)
|
||||
self.model_list.append(model_info)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"初始化模型 {model['name']} 失败: {type(e)} {e} ,请检查配置文件")
|
||||
|
||||
async def load_model_from_db(self):
|
||||
"""从数据库加载模型"""
|
||||
self.llm_models = []
|
||||
|
||||
# forge requester class dict
|
||||
requester_dict: dict[str, type[requester.LLMAPIRequester]] = {}
|
||||
for component in self.requester_components:
|
||||
requester_dict[component.metadata.name] = component.get_python_component_class()
|
||||
|
||||
self.requester_dict = requester_dict
|
||||
|
||||
await self.load_model_from_db()
|
||||
|
||||
async def load_model_from_db(self):
|
||||
"""从数据库加载模型"""
|
||||
self.llm_models = []
|
||||
|
||||
# llm models
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.LLMModel)
|
||||
@@ -160,19 +92,36 @@ class ModelManager:
|
||||
|
||||
# load models
|
||||
for llm_model in llm_models:
|
||||
assert isinstance(llm_model, persistence_model.LLMModel)
|
||||
runtime_llm_model = RuntimeLLMModel(
|
||||
model_entity=llm_model,
|
||||
token_mgr=token.TokenManager(
|
||||
name=llm_model.uuid,
|
||||
tokens=llm_model.api_keys
|
||||
),
|
||||
requester=requester_dict[llm_model.requester](
|
||||
ap=self.ap,
|
||||
config=llm_model.requester_config
|
||||
)
|
||||
await self.load_model(llm_model)
|
||||
|
||||
async def load_model(self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict):
|
||||
"""加载模型"""
|
||||
|
||||
if isinstance(model_info, sqlalchemy.Row):
|
||||
model_info = persistence_model.LLMModel(**model_info._mapping)
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.LLMModel(**model_info)
|
||||
|
||||
runtime_llm_model = RuntimeLLMModel(
|
||||
model_entity=model_info,
|
||||
token_mgr=token.TokenManager(
|
||||
name=model_info.uuid,
|
||||
tokens=model_info.api_keys,
|
||||
),
|
||||
requester=self.requester_dict[model_info.requester](
|
||||
ap=self.ap,
|
||||
config=model_info.requester_config
|
||||
)
|
||||
self.llm_models.append(runtime_llm_model)
|
||||
)
|
||||
print(runtime_llm_model, runtime_llm_model.model_entity.name, "loaded")
|
||||
self.llm_models.append(runtime_llm_model)
|
||||
|
||||
async def remove_model(self, model_uuid: str):
|
||||
"""移除模型"""
|
||||
for model in self.llm_models:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
self.llm_models.remove(model)
|
||||
return
|
||||
|
||||
def get_available_requesters_info(self) -> list[dict]:
|
||||
"""获取所有可用的请求器"""
|
||||
|
||||
Reference in New Issue
Block a user