diff --git a/pkg/api/http/controller/groups/provider/models.py b/pkg/api/http/controller/groups/provider/models.py index 683fac01..bb77986c 100644 --- a/pkg/api/http/controller/groups/provider/models.py +++ b/pkg/api/http/controller/groups/provider/models.py @@ -36,3 +36,11 @@ class LLMModelsRouterGroup(group.RouterGroup): await self.ap.model_service.delete_llm_model(model_uuid) return self.success() + + @self.route('//test', methods=['POST']) + async def _(model_uuid: str) -> str: + json_data = await quart.request.json + + await self.ap.model_service.test_llm_model(model_uuid, json_data) + + return self.success() diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 080abb9d..74fb4e02 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -6,6 +6,8 @@ import sqlalchemy from ....core import app from ....entity.persistence import model as persistence_model from ....entity.persistence import pipeline as persistence_pipeline +from ....provider.modelmgr import requester as model_requester +from ....provider import entities as llm_entities class ModelsService: @@ -78,3 +80,26 @@ class ModelsService: ) await self.ap.model_mgr.remove_llm_model(model_uuid) + + async def test_llm_model(self, model_uuid: str, model_data: dict) -> None: + runtime_llm_model: model_requester.RuntimeLLMModel | None = None + + if model_uuid != '_': + for model in self.ap.model_mgr.llm_models: + if model.model_entity.uuid == model_uuid: + runtime_llm_model = model + break + + if runtime_llm_model is None: + raise Exception('model not found') + + else: + runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data) + + await runtime_llm_model.requester.invoke_llm( + query=None, + model=runtime_llm_model, + messages=[llm_entities.Message(role='user', content='Hello, world!')], + funcs=[], + extra_args={}, + ) diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 6e50fc37..6bc80fe3 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -4,9 +4,6 @@ import sqlalchemy from . import entities, requester from ...core import app -from ...core import entities as core_entities -from .. import entities as llm_entities -from ..tools import entities as tools_entities from ...discover import engine from . import token from ...entity.persistence import model as persistence_model @@ -69,12 +66,11 @@ class ModelManager: for llm_model in llm_models: await self.load_llm_model(llm_model) - async def load_llm_model( + async def init_runtime_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """加载模型""" - + """初始化运行时模型""" if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.LLMModel(**model_info._mapping) elif isinstance(model_info, dict): @@ -92,6 +88,15 @@ class ModelManager: ), requester=requester_inst, ) + + return runtime_llm_model + + async def load_llm_model( + self, + model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, + ): + """加载模型""" + runtime_llm_model = await self.init_runtime_llm_model(model_info) self.llm_models.append(runtime_llm_model) async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated @@ -132,12 +137,3 @@ class ModelManager: if component.metadata.name == name: return component return None - - async def invoke_llm( - self, - query: core_entities.Query, - model_uuid: str, - messages: list[llm_entities.Message], - funcs: list[tools_entities.LLMFunction] = None, - ) -> llm_entities.Message: - pass diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index be88fdef..0e86dd51 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -130,6 +130,7 @@ export default function LLMForm({ const [requesterDefaultURLList, setRequesterDefaultURLList] = useState< string[] >([]); + const [modelTesting, setModelTesting] = useState(false); useEffect(() => { initLLMModelFormComponent().then(() => { @@ -308,6 +309,34 @@ export default function LLMForm({ } } + function testLLMModelInForm() { + setModelTesting(true); + httpClient + .testLLMModel('_', { + uuid: '', + name: form.getValues('name'), + description: '', + requester: form.getValues('model_provider'), + requester_config: { + base_url: form.getValues('url'), + timeout: 120, + }, + api_keys: [form.getValues('api_key')], + abilities: form.getValues('abilities'), + extra_args: form.getValues('extra_args'), + }) + .then((res) => { + console.log(res); + toast.success(t('models.testSuccess')); + }) + .catch(() => { + toast.error(t('models.testError')); + }) + .finally(() => { + setModelTesting(false); + }); + } + return (
+ +