feat: 用户账户系统

This commit is contained in:
Junyan Qin
2024-11-17 19:11:44 +08:00
parent 036c2182a5
commit 20e3edba8f
23 changed files with 543 additions and 102 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import abc
import typing
import enum
import quart
from quart.typing import RouteCallable
@@ -23,6 +24,12 @@ def group_class(name: str, path: str) -> None:
return decorator
class AuthType(enum.Enum):
"""认证类型"""
NONE = 'none'
USER_TOKEN = 'user-token'
class RouterGroup(abc.ABC):
name: str
@@ -41,13 +48,30 @@ class RouterGroup(abc.ABC):
async def initialize(self) -> None:
pass
def route(self, rule: str, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
"""注册一个路由"""
def decorator(f: RouteCallable) -> RouteCallable:
nonlocal rule
rule = self.path + rule
async def handler_error(*args, **kwargs):
if auth_type == AuthType.USER_TOKEN:
# 从Authorization头中获取token
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
if not token:
return self.http_status(401, -1, '未提供有效的用户令牌')
try:
user_email = await self.ap.user_service.verify_jwt_token(token)
# 检查f是否接受user_email参数
if 'user_email' in f.__code__.co_varnames:
kwargs['user_email'] = user_email
except Exception as e:
return self.http_status(401, -1, str(e))
try:
return await f(*args, **kwargs)
except Exception as e: # 自动 500
@@ -61,25 +85,22 @@ class RouterGroup(abc.ABC):
return f
return decorator
def _cors(self, response: quart.Response) -> quart.Response:
return response
def success(self, data: typing.Any = None) -> quart.Response:
"""返回一个 200 响应"""
return self._cors(quart.jsonify({
return quart.jsonify({
'code': 0,
'msg': 'ok',
'data': data,
}))
})
def fail(self, code: int, msg: str) -> quart.Response:
"""返回一个异常响应"""
return self._cors(quart.jsonify({
return quart.jsonify({
'code': code,
'msg': msg,
}))
})
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
"""返回一个指定状态码的响应"""

View File

@@ -10,7 +10,7 @@ from .....utils import constants
class SystemRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/info', methods=['GET'])
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
async def _() -> str:
return self.success(
data={

View File

@@ -0,0 +1,43 @@
import quart
import sqlalchemy
from .. import group
from .....persistence.entities import user
@group.group_class('user', '/api/v1/user')
class UserRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'initialized': await self.ap.user_service.is_initialized()
})
if await self.ap.user_service.is_initialized():
return self.fail(1, '系统已初始化')
json_data = await quart.request.json
user_email = json_data['user']
password = json_data['password']
await self.ap.user_service.create_user(user_email, password)
return self.success()
@self.route('/auth', methods=['POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
json_data = await quart.request.json
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
return self.success(data={
'token': token
})
@self.route('/check-token', methods=['GET'])
async def _() -> str:
return self.success()

View File

@@ -7,7 +7,7 @@ import quart
import quart_cors
from ....core import app, entities as core_entities
from .groups import logs, system, settings, plugins, stats
from .groups import logs, system, settings, plugins, stats, user
from . import group

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import sqlalchemy
import argon2
import jwt
import datetime
from ....core import app
from ....persistence.entities import user
from ....utils import constants
class UserService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def is_initialized(self) -> bool:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(user.User).limit(1)
)
result_list = result.all()
return result_list is not None and len(result_list) > 0
async def create_user(self, user_email: str, password: str) -> None:
ph = argon2.PasswordHasher()
hashed_password = ph.hash(password)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(user.User).values(
user=user_email,
password=hashed_password
)
)
async def authenticate(self, user_email: str, password: str) -> str | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(user.User).where(user.User.user == user_email)
)
result_list = result.all()
if result_list is None or len(result_list) == 0:
raise ValueError('用户不存在')
user_obj = result_list[0]
ph = argon2.PasswordHasher()
if not ph.verify(user_obj.password, password):
raise ValueError('密码错误')
return await self.generate_jwt_token(user_email)
async def generate_jwt_token(self, user_email: str) -> str:
jwt_secret = self.ap.instance_secret_meta.data['jwt_secret']
jwt_expire = self.ap.system_cfg.data['http-api']['jwt-expire']
payload = {
'user': user_email,
'iss': 'LangBot-'+constants.edition,
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire)
}
return jwt.encode(payload, jwt_secret, algorithm='HS256')
async def verify_jwt_token(self, token: str) -> str:
jwt_secret = self.ap.instance_secret_meta.data['jwt_secret']
return jwt.decode(token, jwt_secret, algorithms=['HS256'])['user']

View File

@@ -23,6 +23,7 @@ from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
from ..api.http.service import user as user_service
from ..utils import logcache, ip
from . import taskmgr
from . import entities as core_entities
@@ -74,6 +75,8 @@ class Application:
llm_models_meta: config_mgr.ConfigManager = None
instance_secret_meta: config_mgr.ConfigManager = None
# =========================
ctr_mgr: center_mgr.V2CenterAPI = None
@@ -100,6 +103,10 @@ class Application:
log_cache: logcache.LogCache = None
# ========= HTTP Services =========
user_service: user_service.UserService = None
def __init__(self):
pass

View File

@@ -21,6 +21,8 @@ required_deps = {
"aiosqlite": "aiosqlite",
"aiofiles": "aiofiles",
"aioshutil": "aioshutil",
"argon2": "argon2-cffi",
"jwt": "pyjwt",
}

View File

@@ -17,7 +17,8 @@ class HttpApiConfigMigration(migration.Migration):
self.ap.system_cfg.data['http-api'] = {
"enable": True,
"host": "0.0.0.0",
"port": 5300
"port": 5300,
"jwt-expire": 604800
}
self.ap.system_cfg.data['persistence'] = {

View File

@@ -17,6 +17,7 @@ from ...provider import runnermgr
from ...platform import manager as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
from ...api.http.service import user as user_service
from ...utils import logcache
from .. import taskmgr
@@ -112,5 +113,8 @@ class BuildAppStage(stage.BootingStage):
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl
user_service_inst = user_service.UserService(ap)
ap.user_service = user_service_inst
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import secrets
from .. import stage, app
from ..bootutils import config
from ...config import settings as settings_mgr
@@ -75,3 +77,8 @@ class LoadConfigStage(stage.BootingStage):
ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
await ap.llm_models_meta.dump_config()
ap.instance_secret_meta = await config.load_json_config("data/metadata/instance-secret.json", template_data={
'jwt_secret': secrets.token_hex(16)
})
await ap.instance_secret_meta.dump_config()

View File

View File

@@ -0,0 +1,5 @@
import sqlalchemy.orm
class Base(sqlalchemy.orm.DeclarativeBase):
pass

View File

@@ -0,0 +1,11 @@
import sqlalchemy
from .base import Base
class User(Base):
__tablename__ = 'users'
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)

View File

@@ -7,6 +7,7 @@ import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
import sqlalchemy
from . import database
from .entities import user, base
from ..core import app
from .databases import sqlite
@@ -23,7 +24,7 @@ class PersistenceManager:
def __init__(self, ap: app.Application):
self.ap = ap
self.meta = sqlalchemy.MetaData()
self.meta = base.Base.metadata
async def initialize(self):
@@ -46,10 +47,11 @@ class PersistenceManager:
self,
*args,
**kwargs
):
) -> sqlalchemy.engine.cursor.CursorResult:
async with self.get_db_engine().connect() as conn:
await conn.execute(*args, **kwargs)
result = await conn.execute(*args, **kwargs)
await conn.commit()
return result
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine()

View File

@@ -1,3 +1,5 @@
semantic_version = "v3.3.1.1"
debug_mode = False
debug_mode = False
edition = 'community'