mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: llmmodels crud
This commit is contained in:
@@ -4,6 +4,7 @@ import abc
|
||||
import typing
|
||||
import enum
|
||||
import quart
|
||||
import traceback
|
||||
from quart.typing import RouteCallable
|
||||
|
||||
from ....core import app
|
||||
@@ -75,7 +76,9 @@ class RouterGroup(abc.ABC):
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
except Exception as e: # 自动 500
|
||||
return self.http_status(500, -2, str(e))
|
||||
traceback.print_exc()
|
||||
# return self.http_status(500, -2, str(e))
|
||||
return self.http_status(500, -2, 'internal server error')
|
||||
|
||||
new_f = handler_error
|
||||
new_f.__name__ = (self.name + rule).replace('/', '__')
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import quart
|
||||
import uuid
|
||||
|
||||
from .. import group
|
||||
from .....entity.persistence import model
|
||||
@@ -15,13 +16,28 @@ class LLMModelsRouterGroup(group.RouterGroup):
|
||||
'models': await self.ap.model_service.get_llm_models()
|
||||
})
|
||||
elif quart.request.method == 'POST':
|
||||
pass
|
||||
json_data = await quart.request.json
|
||||
|
||||
await self.ap.model_service.create_llm_model(json_data)
|
||||
|
||||
return self.success()
|
||||
|
||||
@self.route('/<model_uuid>', methods=['GET', 'PUT', 'DELETE'])
|
||||
async def _(model_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
pass
|
||||
model = await self.ap.model_service.get_llm_model(model_uuid)
|
||||
|
||||
if model is None:
|
||||
return self.http_status(404, -1, 'model not found')
|
||||
|
||||
return self.success(data=model)
|
||||
elif quart.request.method == 'PUT':
|
||||
pass
|
||||
json_data = await quart.request.json
|
||||
|
||||
await self.ap.model_service.update_llm_model(model_uuid, json_data)
|
||||
|
||||
return self.success()
|
||||
elif quart.request.method == 'DELETE':
|
||||
pass
|
||||
await self.ap.model_service.delete_llm_model(model_uuid)
|
||||
|
||||
return self.success()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import datetime
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import model
|
||||
from ....entity.persistence import model as persistence_model
|
||||
|
||||
|
||||
class ModelsService:
|
||||
@@ -13,23 +15,47 @@ class ModelsService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_llm_models(self) -> list[model.LLMModel]:
|
||||
async def get_llm_models(self) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(model.LLMModel)
|
||||
sqlalchemy.select(persistence_model.LLMModel)
|
||||
)
|
||||
|
||||
result_list = result.all()
|
||||
|
||||
return result_list
|
||||
models = result.all()
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
|
||||
for model in models
|
||||
]
|
||||
|
||||
async def create_llm_model(self, model: model.LLMModel) -> None:
|
||||
pass
|
||||
async def create_llm_model(self, model_data: dict) -> None:
|
||||
|
||||
async def get_llm_model(self, model_uuid: str) -> model.LLMModel:
|
||||
pass
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
async def update_llm_model(self, model: model.LLMModel) -> None:
|
||||
pass
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.LLMModel).values(
|
||||
**model_data
|
||||
)
|
||||
)
|
||||
# TODO: add to runtime
|
||||
|
||||
async def get_llm_model(self, model_uuid: str) -> dict | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
||||
)
|
||||
|
||||
model = result.first()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
|
||||
|
||||
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data)
|
||||
)
|
||||
|
||||
async def delete_llm_model(self, model_uuid: str) -> None:
|
||||
pass
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
||||
)
|
||||
# TODO: delete from runtime
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import typing
|
||||
|
||||
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
|
||||
import sqlalchemy
|
||||
@@ -55,3 +56,9 @@ class PersistenceManager:
|
||||
|
||||
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
|
||||
return self.db.get_engine()
|
||||
|
||||
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
|
||||
return {
|
||||
column.name: getattr(data, column.name) if not isinstance(getattr(data, column.name), (datetime.datetime)) else getattr(data, column.name).isoformat()
|
||||
for column in model.__table__.columns
|
||||
}
|
||||
|
||||
@@ -12,6 +12,24 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm
|
||||
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
||||
|
||||
|
||||
class RuntimeLLMModel:
|
||||
"""运行时模型"""
|
||||
|
||||
model_entity: model.LLMModel
|
||||
"""模型数据"""
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
"""api key管理器"""
|
||||
|
||||
requester: requester.LLMAPIRequester
|
||||
"""请求器实例"""
|
||||
|
||||
def __init__(self, model_entity: model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester):
|
||||
self.model_entity = model_entity
|
||||
self.token_mgr = token_mgr
|
||||
self.requester = requester
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""模型管理器"""
|
||||
|
||||
@@ -25,7 +43,9 @@ class ModelManager:
|
||||
|
||||
token_mgrs: dict[str, token.TokenManager]
|
||||
|
||||
models: list[model.LLMModel]
|
||||
# ====== 4.0 ======
|
||||
|
||||
llm_models: list[RuntimeLLMModel]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
@@ -33,7 +53,7 @@ class ModelManager:
|
||||
self.model_list = []
|
||||
self.requesters = {}
|
||||
self.token_mgrs = {}
|
||||
self.models = []
|
||||
self.llm_models = []
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||
"""通过名称获取模型
|
||||
|
||||
Reference in New Issue
Block a user