diff --git a/pkg/api/http/controller/group.py b/pkg/api/http/controller/group.py index 5a6ab97e..7186802f 100644 --- a/pkg/api/http/controller/group.py +++ b/pkg/api/http/controller/group.py @@ -4,6 +4,7 @@ import abc import typing import enum import quart +import traceback from quart.typing import RouteCallable from ....core import app @@ -75,7 +76,9 @@ class RouterGroup(abc.ABC): try: return await f(*args, **kwargs) except Exception as e: # 自动 500 - return self.http_status(500, -2, str(e)) + traceback.print_exc() + # return self.http_status(500, -2, str(e)) + return self.http_status(500, -2, 'internal server error') new_f = handler_error new_f.__name__ = (self.name + rule).replace('/', '__') diff --git a/pkg/api/http/controller/groups/models.py b/pkg/api/http/controller/groups/models.py index e8921d5f..d88df866 100644 --- a/pkg/api/http/controller/groups/models.py +++ b/pkg/api/http/controller/groups/models.py @@ -1,4 +1,5 @@ import quart +import uuid from .. import group from .....entity.persistence import model @@ -15,13 +16,28 @@ class LLMModelsRouterGroup(group.RouterGroup): 'models': await self.ap.model_service.get_llm_models() }) elif quart.request.method == 'POST': - pass + json_data = await quart.request.json + + await self.ap.model_service.create_llm_model(json_data) + + return self.success() @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(model_uuid: str) -> str: if quart.request.method == 'GET': - pass + model = await self.ap.model_service.get_llm_model(model_uuid) + + if model is None: + return self.http_status(404, -1, 'model not found') + + return self.success(data=model) elif quart.request.method == 'PUT': - pass + json_data = await quart.request.json + + await self.ap.model_service.update_llm_model(model_uuid, json_data) + + return self.success() elif quart.request.method == 'DELETE': - pass \ No newline at end of file + await self.ap.model_service.delete_llm_model(model_uuid) + + return self.success() diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 729b815c..a68999fd 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -1,9 +1,11 @@ from __future__ import annotations +import uuid +import datetime import sqlalchemy from ....core import app -from ....entity.persistence import model +from ....entity.persistence import model as persistence_model class ModelsService: @@ -13,23 +15,47 @@ class ModelsService: def __init__(self, ap: app.Application) -> None: self.ap = ap - async def get_llm_models(self) -> list[model.LLMModel]: + async def get_llm_models(self) -> list[dict]: result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(model.LLMModel) + sqlalchemy.select(persistence_model.LLMModel) ) - result_list = result.all() - - return result_list + models = result.all() + return [ + self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) + for model in models + ] - async def create_llm_model(self, model: model.LLMModel) -> None: - pass + async def create_llm_model(self, model_data: dict) -> None: - async def get_llm_model(self, model_uuid: str) -> model.LLMModel: - pass + model_data['uuid'] = str(uuid.uuid4()) - async def update_llm_model(self, model: model.LLMModel) -> None: - pass + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_model.LLMModel).values( + **model_data + ) + ) + # TODO: add to runtime + + async def get_llm_model(self, model_uuid: str) -> dict | None: + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) + ) + + model = result.first() + + if model is None: + return None + + return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) + + async def update_llm_model(self, model_uuid: str, model_data: dict) -> None: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data) + ) async def delete_llm_model(self, model_uuid: str) -> None: - pass + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) + ) + # TODO: delete from runtime \ No newline at end of file diff --git a/pkg/persistence/mgr.py b/pkg/persistence/mgr.py index 27ffe522..67b618e7 100644 --- a/pkg/persistence/mgr.py +++ b/pkg/persistence/mgr.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import datetime +import typing import sqlalchemy.ext.asyncio as sqlalchemy_asyncio import sqlalchemy @@ -55,3 +56,9 @@ class PersistenceManager: def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine: return self.db.get_engine() + + def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict: + return { + column.name: getattr(data, column.name) if not isinstance(getattr(data, column.name), (datetime.datetime)) else getattr(data, column.name).isoformat() + for column in model.__table__.columns + } diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 3d86dfad..2951d9ab 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -12,6 +12,24 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" +class RuntimeLLMModel: + """运行时模型""" + + model_entity: model.LLMModel + """模型数据""" + + token_mgr: token.TokenManager + """api key管理器""" + + requester: requester.LLMAPIRequester + """请求器实例""" + + def __init__(self, model_entity: model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester): + self.model_entity = model_entity + self.token_mgr = token_mgr + self.requester = requester + + class ModelManager: """模型管理器""" @@ -25,7 +43,9 @@ class ModelManager: token_mgrs: dict[str, token.TokenManager] - models: list[model.LLMModel] + # ====== 4.0 ====== + + llm_models: list[RuntimeLLMModel] def __init__(self, ap: app.Application): self.ap = ap @@ -33,7 +53,7 @@ class ModelManager: self.model_list = [] self.requesters = {} self.token_mgrs = {} - self.models = [] + self.llm_models = [] async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: """通过名称获取模型