mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-08 23:06:03 +00:00
feat: 用户账户系统
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
import enum
|
||||
import quart
|
||||
from quart.typing import RouteCallable
|
||||
|
||||
@@ -23,6 +24,12 @@ def group_class(name: str, path: str) -> None:
|
||||
return decorator
|
||||
|
||||
|
||||
class AuthType(enum.Enum):
|
||||
"""认证类型"""
|
||||
NONE = 'none'
|
||||
USER_TOKEN = 'user-token'
|
||||
|
||||
|
||||
class RouterGroup(abc.ABC):
|
||||
|
||||
name: str
|
||||
@@ -41,13 +48,30 @@ class RouterGroup(abc.ABC):
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
def route(self, rule: str, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
|
||||
def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
|
||||
"""注册一个路由"""
|
||||
def decorator(f: RouteCallable) -> RouteCallable:
|
||||
nonlocal rule
|
||||
rule = self.path + rule
|
||||
|
||||
async def handler_error(*args, **kwargs):
|
||||
|
||||
if auth_type == AuthType.USER_TOKEN:
|
||||
# 从Authorization头中获取token
|
||||
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
|
||||
|
||||
if not token:
|
||||
return self.http_status(401, -1, '未提供有效的用户令牌')
|
||||
|
||||
try:
|
||||
user_email = await self.ap.user_service.verify_jwt_token(token)
|
||||
|
||||
# 检查f是否接受user_email参数
|
||||
if 'user_email' in f.__code__.co_varnames:
|
||||
kwargs['user_email'] = user_email
|
||||
except Exception as e:
|
||||
return self.http_status(401, -1, str(e))
|
||||
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
except Exception as e: # 自动 500
|
||||
@@ -61,25 +85,22 @@ class RouterGroup(abc.ABC):
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
def _cors(self, response: quart.Response) -> quart.Response:
|
||||
return response
|
||||
|
||||
def success(self, data: typing.Any = None) -> quart.Response:
|
||||
"""返回一个 200 响应"""
|
||||
return self._cors(quart.jsonify({
|
||||
return quart.jsonify({
|
||||
'code': 0,
|
||||
'msg': 'ok',
|
||||
'data': data,
|
||||
}))
|
||||
})
|
||||
|
||||
def fail(self, code: int, msg: str) -> quart.Response:
|
||||
"""返回一个异常响应"""
|
||||
|
||||
return self._cors(quart.jsonify({
|
||||
return quart.jsonify({
|
||||
'code': code,
|
||||
'msg': msg,
|
||||
}))
|
||||
})
|
||||
|
||||
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
|
||||
"""返回一个指定状态码的响应"""
|
||||
|
||||
@@ -10,7 +10,7 @@ from .....utils import constants
|
||||
class SystemRouterGroup(group.RouterGroup):
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/info', methods=['GET'])
|
||||
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
return self.success(
|
||||
data={
|
||||
|
||||
43
pkg/api/http/controller/groups/user.py
Normal file
43
pkg/api/http/controller/groups/user.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import quart
|
||||
import sqlalchemy
|
||||
|
||||
from .. import group
|
||||
from .....persistence.entities import user
|
||||
|
||||
|
||||
@group.group_class('user', '/api/v1/user')
|
||||
class UserRouterGroup(group.RouterGroup):
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
return self.success(data={
|
||||
'initialized': await self.ap.user_service.is_initialized()
|
||||
})
|
||||
|
||||
if await self.ap.user_service.is_initialized():
|
||||
return self.fail(1, '系统已初始化')
|
||||
|
||||
json_data = await quart.request.json
|
||||
|
||||
user_email = json_data['user']
|
||||
password = json_data['password']
|
||||
|
||||
await self.ap.user_service.create_user(user_email, password)
|
||||
|
||||
return self.success()
|
||||
|
||||
@self.route('/auth', methods=['POST'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
json_data = await quart.request.json
|
||||
|
||||
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
|
||||
|
||||
return self.success(data={
|
||||
'token': token
|
||||
})
|
||||
|
||||
@self.route('/check-token', methods=['GET'])
|
||||
async def _() -> str:
|
||||
return self.success()
|
||||
@@ -7,7 +7,7 @@ import quart
|
||||
import quart_cors
|
||||
|
||||
from ....core import app, entities as core_entities
|
||||
from .groups import logs, system, settings, plugins, stats
|
||||
from .groups import logs, system, settings, plugins, stats, user
|
||||
from . import group
|
||||
|
||||
|
||||
|
||||
74
pkg/api/http/service/user.py
Normal file
74
pkg/api/http/service/user.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy
|
||||
import argon2
|
||||
import jwt
|
||||
import datetime
|
||||
|
||||
from ....core import app
|
||||
from ....persistence.entities import user
|
||||
from ....utils import constants
|
||||
|
||||
|
||||
class UserService:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def is_initialized(self) -> bool:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(user.User).limit(1)
|
||||
)
|
||||
|
||||
result_list = result.all()
|
||||
return result_list is not None and len(result_list) > 0
|
||||
|
||||
async def create_user(self, user_email: str, password: str) -> None:
|
||||
ph = argon2.PasswordHasher()
|
||||
|
||||
hashed_password = ph.hash(password)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(user.User).values(
|
||||
user=user_email,
|
||||
password=hashed_password
|
||||
)
|
||||
)
|
||||
|
||||
async def authenticate(self, user_email: str, password: str) -> str | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(user.User).where(user.User.user == user_email)
|
||||
)
|
||||
|
||||
result_list = result.all()
|
||||
|
||||
if result_list is None or len(result_list) == 0:
|
||||
raise ValueError('用户不存在')
|
||||
|
||||
user_obj = result_list[0]
|
||||
|
||||
ph = argon2.PasswordHasher()
|
||||
|
||||
if not ph.verify(user_obj.password, password):
|
||||
raise ValueError('密码错误')
|
||||
|
||||
return await self.generate_jwt_token(user_email)
|
||||
|
||||
async def generate_jwt_token(self, user_email: str) -> str:
|
||||
jwt_secret = self.ap.instance_secret_meta.data['jwt_secret']
|
||||
jwt_expire = self.ap.system_cfg.data['http-api']['jwt-expire']
|
||||
|
||||
payload = {
|
||||
'user': user_email,
|
||||
'iss': 'LangBot-'+constants.edition,
|
||||
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire)
|
||||
}
|
||||
|
||||
return jwt.encode(payload, jwt_secret, algorithm='HS256')
|
||||
|
||||
async def verify_jwt_token(self, token: str) -> str:
|
||||
jwt_secret = self.ap.instance_secret_meta.data['jwt_secret']
|
||||
|
||||
return jwt.decode(token, jwt_secret, algorithms=['HS256'])['user']
|
||||
@@ -23,6 +23,7 @@ from ..pipeline import controller, stagemgr
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
|
||||
from ..persistence import mgr as persistencemgr
|
||||
from ..api.http.controller import main as http_controller
|
||||
from ..api.http.service import user as user_service
|
||||
from ..utils import logcache, ip
|
||||
from . import taskmgr
|
||||
from . import entities as core_entities
|
||||
@@ -74,6 +75,8 @@ class Application:
|
||||
|
||||
llm_models_meta: config_mgr.ConfigManager = None
|
||||
|
||||
instance_secret_meta: config_mgr.ConfigManager = None
|
||||
|
||||
# =========================
|
||||
|
||||
ctr_mgr: center_mgr.V2CenterAPI = None
|
||||
@@ -100,6 +103,10 @@ class Application:
|
||||
|
||||
log_cache: logcache.LogCache = None
|
||||
|
||||
# ========= HTTP Services =========
|
||||
|
||||
user_service: user_service.UserService = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ required_deps = {
|
||||
"aiosqlite": "aiosqlite",
|
||||
"aiofiles": "aiofiles",
|
||||
"aioshutil": "aioshutil",
|
||||
"argon2": "argon2-cffi",
|
||||
"jwt": "pyjwt",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ class HttpApiConfigMigration(migration.Migration):
|
||||
self.ap.system_cfg.data['http-api'] = {
|
||||
"enable": True,
|
||||
"host": "0.0.0.0",
|
||||
"port": 5300
|
||||
"port": 5300,
|
||||
"jwt-expire": 604800
|
||||
}
|
||||
|
||||
self.ap.system_cfg.data['persistence'] = {
|
||||
|
||||
@@ -17,6 +17,7 @@ from ...provider import runnermgr
|
||||
from ...platform import manager as im_mgr
|
||||
from ...persistence import mgr as persistencemgr
|
||||
from ...api.http.controller import main as http_controller
|
||||
from ...api.http.service import user as user_service
|
||||
from ...utils import logcache
|
||||
from .. import taskmgr
|
||||
|
||||
@@ -112,5 +113,8 @@ class BuildAppStage(stage.BootingStage):
|
||||
await http_ctrl.initialize()
|
||||
ap.http_ctrl = http_ctrl
|
||||
|
||||
user_service_inst = user_service.UserService(ap)
|
||||
ap.user_service = user_service_inst
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
|
||||
from .. import stage, app
|
||||
from ..bootutils import config
|
||||
from ...config import settings as settings_mgr
|
||||
@@ -75,3 +77,8 @@ class LoadConfigStage(stage.BootingStage):
|
||||
|
||||
ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
|
||||
await ap.llm_models_meta.dump_config()
|
||||
|
||||
ap.instance_secret_meta = await config.load_json_config("data/metadata/instance-secret.json", template_data={
|
||||
'jwt_secret': secrets.token_hex(16)
|
||||
})
|
||||
await ap.instance_secret_meta.dump_config()
|
||||
|
||||
0
pkg/persistence/entities/__init__.py
Normal file
0
pkg/persistence/entities/__init__.py
Normal file
5
pkg/persistence/entities/base.py
Normal file
5
pkg/persistence/entities/base.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import sqlalchemy.orm
|
||||
|
||||
|
||||
class Base(sqlalchemy.orm.DeclarativeBase):
|
||||
pass
|
||||
11
pkg/persistence/entities/user.py
Normal file
11
pkg/persistence/entities/user.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import sqlalchemy
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'users'
|
||||
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
|
||||
user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
@@ -7,6 +7,7 @@ import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
|
||||
import sqlalchemy
|
||||
|
||||
from . import database
|
||||
from .entities import user, base
|
||||
from ..core import app
|
||||
from .databases import sqlite
|
||||
|
||||
@@ -23,7 +24,7 @@ class PersistenceManager:
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.meta = sqlalchemy.MetaData()
|
||||
self.meta = base.Base.metadata
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -46,10 +47,11 @@ class PersistenceManager:
|
||||
self,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
await conn.execute(*args, **kwargs)
|
||||
result = await conn.execute(*args, **kwargs)
|
||||
await conn.commit()
|
||||
return result
|
||||
|
||||
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
|
||||
return self.db.get_engine()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
semantic_version = "v3.3.1.1"
|
||||
|
||||
debug_mode = False
|
||||
debug_mode = False
|
||||
|
||||
edition = 'community'
|
||||
Reference in New Issue
Block a user