From 81481c9050cbfe7042258937a386f68ffe262ef0 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Mon, 17 Mar 2025 22:03:09 +0800 Subject: [PATCH] feat: new llm initialization logic --- pkg/command/operators/model.py | 2 +- pkg/provider/modelmgr/modelmgr.py | 39 +++++++++++++++++++++++++++--- pkg/provider/modelmgr/requester.py | 12 ++++++++- pkg/provider/modelmgr/token.py | 6 ++--- 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/pkg/command/operators/model.py b/pkg/command/operators/model.py index 692e2728..f46c9590 100644 --- a/pkg/command/operators/model.py +++ b/pkg/command/operators/model.py @@ -54,7 +54,7 @@ class ModelShowOperator(operator.CommandOperator): if model.model_name is not None: content += f"请求模型名称: {model.model_name}\n" content += f"请求器: {model.requester.name}\n" - content += f"密钥组: {model.token_mgr.provider}\n" + content += f"密钥组: {model.token_mgr.name}\n" content += f"支持视觉: {model.vision_supported}\n" content += f"支持工具: {model.tool_call_supported}\n" diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 1884da4b..cafaf7e9 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -1,12 +1,13 @@ from __future__ import annotations import aiohttp +import sqlalchemy from . import entities, requester from ...core import app from ...discover import engine from . import token -from ...entity.persistence import model +from ...entity.persistence import model as persistence_model from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl, volcarkchatcmpl, xaichatcmpl, zhipuaichatcmpl, lmstudiochatcmpl, siliconflowchatcmpl, volcarkchatcmpl FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" @@ -15,7 +16,7 @@ FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_lis class RuntimeLLMModel: """运行时模型""" - model_entity: model.LLMModel + model_entity: persistence_model.LLMModel """模型数据""" token_mgr: token.TokenManager @@ -24,7 +25,7 @@ class RuntimeLLMModel: requester: requester.LLMAPIRequester """请求器实例""" - def __init__(self, model_entity: model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester): + def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester): self.model_entity = model_entity self.token_mgr = token_mgr self.requester = requester @@ -141,6 +142,38 @@ class ModelManager: except Exception as e: self.ap.logger.error(f"初始化模型 {model['name']} 失败: {type(e)} {e} ,请检查配置文件") + async def load_model_from_db(self): + """从数据库加载模型""" + self.llm_models = [] + + # forge requester class dict + requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} + for component in self.requester_components: + requester_dict[component.metadata.name] = component.get_python_component_class() + + # llm models + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.LLMModel) + ) + + llm_models = result.all() + + # load models + for llm_model in llm_models: + assert isinstance(llm_model, persistence_model.LLMModel) + runtime_llm_model = RuntimeLLMModel( + model_entity=llm_model, + token_mgr=token.TokenManager( + name=llm_model.uuid, + tokens=llm_model.api_keys + ), + requester=requester_dict[llm_model.requester]( + ap=self.ap, + config=llm_model.requester_config + ) + ) + self.llm_models.append(runtime_llm_model) + def get_available_requesters_info(self) -> list[dict]: """获取所有可用的请求器""" return [ diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 147a97c4..7f13c58b 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -17,8 +17,16 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): ap: app.Application - def __init__(self, ap: app.Application): + default_config: dict[str, typing.Any] = {} + + requester_cfg: dict[str, typing.Any] = {} + + def __init__(self, ap: app.Application, config: dict[str, typing.Any]): self.ap = ap + self.requester_cfg = { + **self.default_config + } + self.requester_cfg.update(config) async def initialize(self): pass @@ -40,6 +48,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): model: modelmgr_entities.LLMModelInfo, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: """调用API @@ -47,6 +56,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): model (modelmgr_entities.LLMModelInfo): 使用的模型信息 messages (typing.List[llm_entities.Message]): 消息对象列表 funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: llm_entities.Message: 返回消息对象 diff --git a/pkg/provider/modelmgr/token.py b/pkg/provider/modelmgr/token.py index f6f9436d..eeec6986 100644 --- a/pkg/provider/modelmgr/token.py +++ b/pkg/provider/modelmgr/token.py @@ -7,14 +7,14 @@ class TokenManager(): """鉴权 Token 管理器 """ - provider: str + name: str tokens: list[str] using_token_index: typing.Optional[int] = 0 - def __init__(self, provider: str, tokens: list[str]): - self.provider = provider + def __init__(self, name: str, tokens: list[str]): + self.name = name self.tokens = tokens self.using_token_index = 0