feat: add supports for testing llm models (#1454)

* feat: add supports for testing llm models

* fix: linter error
This commit is contained in:
Junyan Qin (Chin)
2025-05-19 23:10:04 +08:00
committed by GitHub
parent aba51409a7
commit a7d2a68639
7 changed files with 92 additions and 15 deletions

View File

@@ -36,3 +36,11 @@ class LLMModelsRouterGroup(group.RouterGroup):
await self.ap.model_service.delete_llm_model(model_uuid)
return self.success()
@self.route('/<model_uuid>/test', methods=['POST'])
async def _(model_uuid: str) -> str:
json_data = await quart.request.json
await self.ap.model_service.test_llm_model(model_uuid, json_data)
return self.success()

View File

@@ -6,6 +6,8 @@ import sqlalchemy
from ....core import app
from ....entity.persistence import model as persistence_model
from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester
from ....provider import entities as llm_entities
class ModelsService:
@@ -78,3 +80,26 @@ class ModelsService:
)
await self.ap.model_mgr.remove_llm_model(model_uuid)
async def test_llm_model(self, model_uuid: str, model_data: dict) -> None:
runtime_llm_model: model_requester.RuntimeLLMModel | None = None
if model_uuid != '_':
for model in self.ap.model_mgr.llm_models:
if model.model_entity.uuid == model_uuid:
runtime_llm_model = model
break
if runtime_llm_model is None:
raise Exception('model not found')
else:
runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data)
await runtime_llm_model.requester.invoke_llm(
query=None,
model=runtime_llm_model,
messages=[llm_entities.Message(role='user', content='Hello, world!')],
funcs=[],
extra_args={},
)

View File

@@ -4,9 +4,6 @@ import sqlalchemy
from . import entities, requester
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from ..tools import entities as tools_entities
from ...discover import engine
from . import token
from ...entity.persistence import model as persistence_model
@@ -69,12 +66,11 @@ class ModelManager:
for llm_model in llm_models:
await self.load_llm_model(llm_model)
async def load_llm_model(
async def init_runtime_llm_model(
self,
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
):
"""加载模型"""
"""初始化运行时模型"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.LLMModel(**model_info._mapping)
elif isinstance(model_info, dict):
@@ -92,6 +88,15 @@ class ModelManager:
),
requester=requester_inst,
)
return runtime_llm_model
async def load_llm_model(
self,
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
):
"""加载模型"""
runtime_llm_model = await self.init_runtime_llm_model(model_info)
self.llm_models.append(runtime_llm_model)
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated
@@ -132,12 +137,3 @@ class ModelManager:
if component.metadata.name == name:
return component
return None
async def invoke_llm(
self,
query: core_entities.Query,
model_uuid: str,
messages: list[llm_entities.Message],
funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
pass