style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions

View File

@@ -13,6 +13,7 @@ from ....core import app
preregistered_groups: list[type[RouterGroup]] = []
"""RouterGroup 的预注册列表"""
def group_class(name: str, path: str) -> None:
"""注册一个 RouterGroup"""
@@ -27,12 +28,12 @@ def group_class(name: str, path: str) -> None:
class AuthType(enum.Enum):
"""认证类型"""
NONE = 'none'
USER_TOKEN = 'user-token'
class RouterGroup(abc.ABC):
name: str
path: str
@@ -49,17 +50,24 @@ class RouterGroup(abc.ABC):
async def initialize(self) -> None:
pass
def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
def route(
self,
rule: str,
auth_type: AuthType = AuthType.USER_TOKEN,
**options: typing.Any,
) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
"""注册一个路由"""
def decorator(f: RouteCallable) -> RouteCallable:
nonlocal rule
rule = self.path + rule
async def handler_error(*args, **kwargs):
if auth_type == AuthType.USER_TOKEN:
# 从Authorization头中获取token
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
token = quart.request.headers.get('Authorization', '').replace(
'Bearer ', ''
)
if not token:
return self.http_status(401, -1, '未提供有效的用户令牌')
@@ -75,11 +83,11 @@ class RouterGroup(abc.ABC):
try:
return await f(*args, **kwargs)
except Exception as e: # 自动 500
except Exception: # 自动 500
traceback.print_exc()
# return self.http_status(500, -2, str(e))
return self.http_status(500, -2, 'internal server error')
new_f = handler_error
new_f.__name__ = (self.name + rule).replace('/', '__')
new_f.__doc__ = f.__doc__
@@ -91,20 +99,24 @@ class RouterGroup(abc.ABC):
def success(self, data: typing.Any = None) -> quart.Response:
"""返回一个 200 响应"""
return quart.jsonify({
'code': 0,
'msg': 'ok',
'data': data,
})
return quart.jsonify(
{
'code': 0,
'msg': 'ok',
'data': data,
}
)
def fail(self, code: int, msg: str) -> quart.Response:
"""返回一个异常响应"""
return quart.jsonify({
'code': code,
'msg': msg,
})
return quart.jsonify(
{
'code': code,
'msg': msg,
}
)
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
"""返回一个指定状态码的响应"""
return self.fail(code, msg), status

View File

@@ -1,32 +1,29 @@
from __future__ import annotations
import traceback
import quart
from .....core import app
from .. import group
@group.group_class('logs', '/api/v1/logs')
class LogsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
start_page_number = int(quart.request.args.get('start_page_number', 0))
start_offset = int(quart.request.args.get('start_offset', 0))
logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer(
start_page_number=start_page_number,
start_offset=start_offset
logs_str, end_page_number, end_offset = (
self.ap.log_cache.get_log_by_pointer(
start_page_number=start_page_number, start_offset=start_offset
)
)
return self.success(
data={
"logs": logs_str,
"end_page_number": end_page_number,
"end_offset": end_offset
'logs': logs_str,
'end_page_number': end_page_number,
'end_offset': end_offset,
}
)

View File

@@ -3,46 +3,41 @@ from __future__ import annotations
import quart
from .. import group
from .....entity.persistence import pipeline
@group.group_class('pipelines', '/api/v1/pipelines')
class PipelinesRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'pipelines': await self.ap.pipeline_service.get_pipelines()
})
return self.success(
data={'pipelines': await self.ap.pipeline_service.get_pipelines()}
)
elif quart.request.method == 'POST':
json_data = await quart.request.json
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data)
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(
json_data
)
return self.success(data={
'uuid': pipeline_uuid
})
return self.success(data={'uuid': pipeline_uuid})
@self.route('/_/metadata', methods=['GET'])
async def _() -> str:
return self.success(data={
'configs': await self.ap.pipeline_service.get_pipeline_metadata()
})
return self.success(
data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()}
)
@self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(pipeline_uuid: str) -> str:
if quart.request.method == 'GET':
pipeline = await self.ap.pipeline_service.get_pipeline(pipeline_uuid)
if pipeline is None:
return self.http_status(404, -1, 'pipeline not found')
return self.success(data={
'pipeline': pipeline
})
return self.success(data={'pipeline': pipeline})
elif quart.request.method == 'PUT':
json_data = await quart.request.json
@@ -53,4 +48,3 @@ class PipelinesRouterGroup(group.RouterGroup):
await self.ap.pipeline_service.delete_pipeline(pipeline_uuid)
return self.success()

View File

@@ -5,29 +5,31 @@ from ... import group
@group.group_class('adapters', '/api/v1/platform/adapters')
class AdaptersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
return self.success(data={
'adapters': self.ap.platform_mgr.get_available_adapters_info()
})
return self.success(
data={'adapters': self.ap.platform_mgr.get_available_adapters_info()}
)
@self.route('/<adapter_name>', methods=['GET'])
async def _(adapter_name: str) -> str:
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name)
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(
adapter_name
)
if adapter_info is None:
return self.http_status(404, -1, 'adapter not found')
return self.success(data={
'adapter': adapter_info
})
return self.success(data={'adapter': adapter_info})
@self.route('/<adapter_name>/icon', methods=['GET'])
async def _(adapter_name: str) -> quart.Response:
adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name)
adapter_manifest = (
self.ap.platform_mgr.get_available_adapter_manifest_by_name(
adapter_name
)
)
if adapter_manifest is None:
return self.http_status(404, -1, 'adapter not found')
@@ -37,4 +39,4 @@ class AdaptersRouterGroup(group.RouterGroup):
if icon_path is None:
return self.http_status(404, -1, 'icon not found')
return await quart.send_file(icon_path)
return await quart.send_file(icon_path)

View File

@@ -5,34 +5,27 @@ from ... import group
@group.group_class('bots', '/api/v1/platform/bots')
class BotsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'bots': await self.ap.bot_service.get_bots()
})
return self.success(data={'bots': await self.ap.bot_service.get_bots()})
elif quart.request.method == 'POST':
json_data = await quart.request.json
bot_uuid = await self.ap.bot_service.create_bot(json_data)
return self.success(data={
'uuid': bot_uuid
})
return self.success(data={'uuid': bot_uuid})
@self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(bot_uuid: str) -> str:
if quart.request.method == 'GET':
bot = await self.ap.bot_service.get_bot(bot_uuid)
if bot is None:
return self.http_status(404, -1, 'bot not found')
return self.success(data={
'bot': bot
})
return self.success(data={'bot': bot})
elif quart.request.method == 'PUT':
json_data = await quart.request.json
await self.ap.bot_service.update_bot(bot_uuid, json_data)
return self.success()
elif quart.request.method == 'DELETE':
await self.ap.bot_service.delete_bot(bot_uuid)
return self.success()
return self.success()

View File

@@ -1,17 +1,14 @@
from __future__ import annotations
import traceback
import quart
from .....core import app, taskmgr
from .....core import taskmgr
from .. import group
@group.group_class('plugins', '/api/v1/plugins')
class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
@@ -19,63 +16,69 @@ class PluginsRouterGroup(group.RouterGroup):
plugins_data = [plugin.model_dump() for plugin in plugins]
return self.success(data={
'plugins': plugins_data
})
@self.route('/<author>/<plugin_name>/toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
return self.success(data={'plugins': plugins_data})
@self.route(
'/<author>/<plugin_name>/toggle',
methods=['PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json
target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
return self.success()
@self.route('/<author>/<plugin_name>/update', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/<author>/<plugin_name>/update',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
name=f"plugin-update-{plugin_name}",
label=f"更新插件 {plugin_name}",
context=ctx
kind='plugin-operation',
name=f'plugin-update-{plugin_name}',
label=f'更新插件 {plugin_name}',
context=ctx,
)
return self.success(data={
'task_id': wrapper.id
})
@self.route('/<author>/<plugin_name>', methods=['GET', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
return self.success(data={'task_id': wrapper.id})
@self.route(
'/<author>/<plugin_name>',
methods=['GET', 'DELETE'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
if quart.request.method == 'GET':
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None:
return self.http_status(404, -1, 'plugin not found')
return self.success(data={
'plugin': plugin.model_dump()
})
return self.success(data={'plugin': plugin.model_dump()})
elif quart.request.method == 'DELETE':
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
kind='plugin-operation',
name=f'plugin-remove-{plugin_name}',
label=f'删除插件 {plugin_name}',
context=ctx
context=ctx,
)
return self.success(data={
'task_id': wrapper.id
})
@self.route('/<author>/<plugin_name>/config', methods=['GET', 'PUT'], auth_type=group.AuthType.USER_TOKEN)
return self.success(data={'task_id': wrapper.id})
@self.route(
'/<author>/<plugin_name>/config',
methods=['GET', 'PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> quart.Response:
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None:
return self.http_status(404, -1, 'plugin not found')
if quart.request.method == 'GET':
return self.success(data={
'config': plugin.plugin_config
})
return self.success(data={'config': plugin.plugin_config})
elif quart.request.method == 'PUT':
data = await quart.request.json
@@ -88,21 +91,21 @@ class PluginsRouterGroup(group.RouterGroup):
data = await quart.request.json
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
return self.success()
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
async def _() -> str:
data = await quart.request.json
ctx = taskmgr.TaskContext.new()
short_source_str = data['source'][-8:]
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
kind="plugin-operation",
name=f'plugin-install-github',
kind='plugin-operation',
name='plugin-install-github',
label=f'安装插件 ...{short_source_str}',
context=ctx
context=ctx,
)
return self.success(data={
'task_id': wrapper.id
})
return self.success(data={'task_id': wrapper.id})

View File

@@ -1,28 +1,23 @@
import quart
import uuid
from ... import group
from ......entity.persistence import model
@group.group_class('models/llm', '/api/v1/provider/models/llm')
class LLMModelsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'models': await self.ap.model_service.get_llm_models()
})
return self.success(
data={'models': await self.ap.model_service.get_llm_models()}
)
elif quart.request.method == 'POST':
json_data = await quart.request.json
model_uuid = await self.ap.model_service.create_llm_model(json_data)
return self.success(data={
'uuid': model_uuid
})
return self.success(data={'uuid': model_uuid})
@self.route('/<model_uuid>', methods=['GET', 'DELETE'])
async def _(model_uuid: str) -> str:
@@ -32,9 +27,7 @@ class LLMModelsRouterGroup(group.RouterGroup):
if model is None:
return self.http_status(404, -1, 'model not found')
return self.success(data={
'model': model
})
return self.success(data={'model': model})
# elif quart.request.method == 'PUT':
# json_data = await quart.request.json

View File

@@ -5,29 +5,31 @@ from ... import group
@group.group_class('provider/requesters', '/api/v1/provider/requesters')
class RequestersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> quart.Response:
return self.success(data={
'requesters': self.ap.model_mgr.get_available_requesters_info()
})
return self.success(
data={'requesters': self.ap.model_mgr.get_available_requesters_info()}
)
@self.route('/<requester_name>', methods=['GET'])
async def _(requester_name: str) -> quart.Response:
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name)
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(
requester_name
)
if requester_info is None:
return self.http_status(404, -1, 'requester not found')
return self.success(data={
'requester': requester_info
})
return self.success(data={'requester': requester_info})
@self.route('/<requester_name>/icon', methods=['GET'])
async def _(requester_name: str) -> quart.Response:
requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name)
requester_manifest = (
self.ap.model_mgr.get_available_requester_manifest_by_name(
requester_name
)
)
if requester_manifest is None:
return self.http_status(404, -1, 'requester not found')

View File

@@ -1,23 +1,21 @@
import quart
import asyncio
from .....core import app, taskmgr
from .. import group
@group.group_class('stats', '/api/v1/stats')
class StatsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/basic', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
conv_count = 0
for session in self.ap.sess_mgr.session_list:
conv_count += len(session.conversations if session.conversations is not None else [])
conv_count += len(
session.conversations if session.conversations is not None else []
)
return self.success(data={
'active_session_count': len(self.ap.sess_mgr.session_list),
'conversation_count': conv_count,
'query_count': self.ap.query_pool.query_id_counter,
})
return self.success(
data={
'active_session_count': len(self.ap.sess_mgr.session_list),
'conversation_count': conv_count,
'query_count': self.ap.query_pool.query_id_counter,
}
)

View File

@@ -1,63 +1,62 @@
import quart
import asyncio
from .....core import app, taskmgr
from .. import group
from .....utils import constants
@group.group_class('system', '/api/v1/system')
class SystemRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
async def _() -> str:
return self.success(
data={
"version": constants.semantic_version,
"debug": constants.debug_mode,
"enabled_platform_count": len(self.ap.platform_mgr.get_running_adapters())
'version': constants.semantic_version,
'debug': constants.debug_mode,
'enabled_platform_count': len(
self.ap.platform_mgr.get_running_adapters()
),
}
)
@self.route('/tasks', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
task_type = quart.request.args.get("type")
task_type = quart.request.args.get('type')
if task_type == '':
task_type = None
return self.success(
data=self.ap.task_mgr.get_tasks_dict(task_type)
)
@self.route('/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type))
@self.route(
'/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
async def _(task_id: str) -> str:
task = self.ap.task_mgr.get_task_by_id(int(task_id))
if task is None:
return self.http_status(404, 404, "Task not found")
return self.http_status(404, 404, 'Task not found')
return self.success(data=task.to_dict())
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get("scope")
scope = json_data.get('scope')
await self.ap.reload(
scope=scope
)
await self.ap.reload(scope=scope)
return self.success()
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, "Forbidden")
return self.http_status(403, 403, 'Forbidden')
py_code = await quart.request.data
ap = self.ap
return self.success(data=exec(py_code, {"ap": ap}))
return self.success(data=exec(py_code, {'ap': ap}))

View File

@@ -1,22 +1,19 @@
import quart
import jwt
import argon2
from .. import group
from .....entity.persistence import user
@group.group_class('user', '/api/v1/user')
class UserRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'initialized': await self.ap.user_service.is_initialized()
})
return self.success(
data={'initialized': await self.ap.user_service.is_initialized()}
)
if await self.ap.user_service.is_initialized():
return self.fail(1, '系统已初始化')
@@ -28,24 +25,24 @@ class UserRouterGroup(group.RouterGroup):
await self.ap.user_service.create_user(user_email, password)
return self.success()
@self.route('/auth', methods=['POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
json_data = await quart.request.json
try:
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
token = await self.ap.user_service.authenticate(
json_data['user'], json_data['password']
)
except argon2.exceptions.VerifyMismatchError:
return self.fail(1, '用户名或密码错误')
return self.success(data={
'token': token
})
return self.success(data={'token': token})
@self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
@self.route(
'/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
async def _(user_email: str) -> str:
token = await self.ap.user_service.generate_jwt_token(user_email)
return self.success(data={
'token': token
})
return self.success(data={'token': token})

View File

@@ -7,15 +7,19 @@ import quart
import quart_cors
from ....core import app, entities as core_entities
from ....utils import importutil
from .groups import logs, system, plugins, stats, user, pipelines
from .groups.provider import models, requesters
from .groups.platform import bots, adapters
from . import groups
from . import group
from .groups import provider as groups_provider
from .groups import platform as groups_platform
importutil.import_modules_in_pkg(groups)
importutil.import_modules_in_pkg(groups_provider)
importutil.import_modules_in_pkg(groups_platform)
class HTTPController:
ap: app.Application
quart_app: quart.Quart
@@ -23,7 +27,7 @@ class HTTPController:
def __init__(self, ap: app.Application) -> None:
self.ap = ap
self.quart_app = quart.Quart(__name__)
quart_cors.cors(self.quart_app, allow_origin="*")
quart_cors.cors(self.quart_app, allow_origin='*')
async def initialize(self) -> None:
await self.register_routes()
@@ -37,11 +41,9 @@ class HTTPController:
async def exception_handler(*args, **kwargs):
try:
await self.quart_app.run_task(
*args, **kwargs
)
await self.quart_app.run_task(*args, **kwargs)
except Exception as e:
self.ap.logger.error(f"启动 HTTP 服务失败: {e}")
self.ap.logger.error(f'启动 HTTP 服务失败: {e}')
self.ap.task_mgr.create_task(
exception_handler(
@@ -49,63 +51,62 @@ class HTTPController:
port=self.ap.instance_config.data['api']['port'],
shutdown_trigger=shutdown_trigger_placeholder,
),
name="http-api-quart",
name='http-api-quart',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
# await asyncio.sleep(5)
async def register_routes(self) -> None:
@self.quart_app.route("/healthz")
@self.quart_app.route('/healthz')
async def healthz():
return {"code": 0, "msg": "ok"}
return {'code': 0, 'msg': 'ok'}
for g in group.preregistered_groups:
ginst = g(self.ap, self.quart_app)
await ginst.initialize()
frontend_path = "web/out"
frontend_path = 'web/out'
@self.quart_app.route("/")
@self.quart_app.route('/')
async def index():
return await quart.send_from_directory(frontend_path, "index.html", mimetype="text/html")
return await quart.send_from_directory(
frontend_path, 'index.html', mimetype='text/html'
)
@self.quart_app.route("/<path:path>")
@self.quart_app.route('/<path:path>')
async def static_file(path: str):
if not os.path.exists(os.path.join(frontend_path, path)):
if os.path.exists(os.path.join(frontend_path, path+".html")):
if os.path.exists(os.path.join(frontend_path, path + '.html')):
path += '.html'
else:
return await quart.send_from_directory(frontend_path, '404.html')
mimetype = None
if path.endswith(".html"):
mimetype = "text/html"
elif path.endswith(".js"):
mimetype = "application/javascript"
elif path.endswith(".css"):
mimetype = "text/css"
elif path.endswith(".png"):
mimetype = "image/png"
elif path.endswith(".jpg"):
mimetype = "image/jpeg"
elif path.endswith(".jpeg"):
mimetype = "image/jpeg"
elif path.endswith(".gif"):
mimetype = "image/gif"
elif path.endswith(".svg"):
mimetype = "image/svg+xml"
elif path.endswith(".ico"):
mimetype = "image/x-icon"
elif path.endswith(".json"):
mimetype = "application/json"
elif path.endswith(".txt"):
mimetype = "text/plain"
if path.endswith('.html'):
mimetype = 'text/html'
elif path.endswith('.js'):
mimetype = 'application/javascript'
elif path.endswith('.css'):
mimetype = 'text/css'
elif path.endswith('.png'):
mimetype = 'image/png'
elif path.endswith('.jpg'):
mimetype = 'image/jpeg'
elif path.endswith('.jpeg'):
mimetype = 'image/jpeg'
elif path.endswith('.gif'):
mimetype = 'image/gif'
elif path.endswith('.svg'):
mimetype = 'image/svg+xml'
elif path.endswith('.ico'):
mimetype = 'image/x-icon'
elif path.endswith('.json'):
mimetype = 'application/json'
elif path.endswith('.txt'):
mimetype = 'text/plain'
return await quart.send_from_directory(
frontend_path,
path,
mimetype=mimetype
frontend_path, path, mimetype=mimetype
)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import uuid
import datetime
import sqlalchemy
from ....core import app
@@ -29,13 +28,15 @@ class BotService:
self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
for bot in bots
]
async def get_bot(self, bot_uuid: str) -> dict | None:
"""获取机器人"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.select(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
)
bot = result.first()
if bot is None:
@@ -50,7 +51,9 @@ class BotService:
# checkout the default pipeline
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True)
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
)
)
pipeline = result.first()
if pipeline is not None:
@@ -64,7 +67,7 @@ class BotService:
bot = await self.get_bot(bot_data['uuid'])
await self.ap.platform_mgr.load_bot(bot)
return bot_data['uuid']
async def update_bot(self, bot_uuid: str, bot_data: dict) -> None:
@@ -75,19 +78,24 @@ class BotService:
# set use_pipeline_name
if 'use_pipeline_uuid' in bot_data:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid'])
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid
== bot_data['use_pipeline_uuid']
)
)
pipeline = result.first()
if pipeline is not None:
bot_data['use_pipeline_name'] = pipeline.name
else:
raise Exception("Pipeline not found")
raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.update(persistence_bot.Bot)
.values(bot_data)
.where(persistence_bot.Bot.uuid == bot_uuid)
)
await self.ap.platform_mgr.remove_bot(bot_uuid)
# select from db
bot = await self.get_bot(bot_uuid)
@@ -100,7 +108,7 @@ class BotService:
"""删除机器人"""
await self.ap.platform_mgr.remove_bot(bot_uuid)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.delete(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import uuid
import datetime
import sqlalchemy
from ....core import app
@@ -10,7 +9,6 @@ from ....entity.persistence import pipeline as persistence_pipeline
class ModelsService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
@@ -26,15 +24,12 @@ class ModelsService:
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
for model in models
]
async def create_llm_model(self, model_data: dict) -> str:
async def create_llm_model(self, model_data: dict) -> str:
model_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.LLMModel).values(
**model_data
)
sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
)
llm_model = await self.get_llm_model(model_data['uuid'])
@@ -43,22 +38,24 @@ class ModelsService:
# check if default pipeline has no model bound
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True)
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
)
)
pipeline = result.first()
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
pipeline_config = pipeline.config
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
pipeline_data = {
"config": pipeline_config
}
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
pipeline_data = {'config': pipeline_config}
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
return model_data['uuid']
async def get_llm_model(self, model_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
)
model = result.first()
@@ -66,14 +63,18 @@ class ModelsService:
if model is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
return self.ap.persistence_mgr.serialize_model(
persistence_model.LLMModel, model
)
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
if 'uuid' in model_data:
del model_data['uuid']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data)
sqlalchemy.update(persistence_model.LLMModel)
.where(persistence_model.LLMModel.uuid == model_uuid)
.values(**model_data)
)
await self.ap.model_mgr.remove_llm_model(model_uuid)
@@ -84,7 +85,9 @@ class ModelsService:
async def delete_llm_model(self, model_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
sqlalchemy.delete(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
)
await self.ap.model_mgr.remove_llm_model(model_uuid)

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import uuid
import json
import datetime
import sqlalchemy
from ....core import app
@@ -10,69 +9,79 @@ from ....entity.persistence import pipeline as persistence_pipeline
default_stage_order = [
"GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", # 预处理器
"ConversationMessageTruncator", # 会话消息截断器
"RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器
"LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
'GroupRespondRuleCheckStage', # 群响应规则检查
'BanSessionCheckStage', # 封禁会话检查
'PreContentFilterStage', # 内容过滤前置阶段
'PreProcessor', # 预处理器
'ConversationMessageTruncator', # 会话消息截断器
'RequireRateLimitOccupancy', # 请求速率限制占用
'MessageProcessor', # 处理器
'ReleaseRateLimitOccupancy', # 释放速率限制占用
'PostContentFilterStage', # 内容过滤后置阶段
'ResponseWrapper', # 响应包装器
'LongTextProcessStage', # 长文本处理
'SendResponseBackStage', # 发送响应
]
class PipelineService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_pipeline_metadata(self) -> dict:
return [
self.ap.pipeline_config_meta_trigger.data,
self.ap.pipeline_config_meta_safety.data,
self.ap.pipeline_config_meta_ai.data,
self.ap.pipeline_config_meta_output.data
self.ap.pipeline_config_meta_output.data,
]
async def get_pipelines(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
)
pipelines = result.all()
return [
self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
for pipeline in pipelines
]
async def get_pipeline(self, pipeline_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
)
)
pipeline = result.first()
if pipeline is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
return self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
pipeline_data['uuid'] = str(uuid.uuid4())
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
pipeline_data['stages'] = default_stage_order.copy()
pipeline_data['is_default'] = default
pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
pipeline_data['config'] = json.load(
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data)
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(
**pipeline_data
)
)
pipeline = await self.get_pipeline(pipeline_data['uuid'])
await self.ap.pipeline_mgr.load_pipeline(pipeline)
@@ -90,7 +99,9 @@ class PipelineService:
del pipeline_data['is_default']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
.where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
.values(**pipeline_data)
)
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
@@ -101,6 +112,8 @@ class PipelineService:
async def delete_pipeline(self, pipeline_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
)
)
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)

View File

@@ -11,7 +11,6 @@ from ....utils import constants
class UserService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
@@ -24,7 +23,7 @@ class UserService:
result_list = result.all()
return result_list is not None and len(result_list) > 0
async def create_user(self, user_email: str, password: str) -> None:
ph = argon2.PasswordHasher()
@@ -32,8 +31,7 @@ class UserService:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(user.User).values(
user=user_email,
password=hashed_password
user=user_email, password=hashed_password
)
)
@@ -61,12 +59,12 @@ class UserService:
payload = {
'user': user_email,
'iss': 'LangBot-'+constants.edition,
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire)
'iss': 'LangBot-' + constants.edition,
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire),
}
return jwt.encode(payload, jwt_secret, algorithm='HS256')
async def verify_jwt_token(self, token: str) -> str:
jwt_secret = self.ap.instance_config.data['system']['jwt']['secret']

View File

@@ -1,3 +1,3 @@
"""
审计相关操作
"""
"""

View File

@@ -3,11 +3,9 @@ from __future__ import annotations
import abc
import uuid
import json
import logging
import asyncio
import aiohttp
import requests
from ...core import app, entities as core_entities
@@ -38,22 +36,22 @@ class APIGroup(metaclass=abc.ABCMeta):
"""
执行请求
"""
self._runtime_info["account_id"] = "-1"
self._runtime_info['account_id'] = '-1'
url = self.prefix + path
data = json.dumps(data)
headers["Content-Type"] = "application/json"
headers['Content-Type'] = 'application/json'
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method, url, data=data, params=params, headers=headers, **kwargs
) as resp:
self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.text())
self.ap.logger.debug('data: %s', data)
self.ap.logger.debug('ret: %s', await resp.text())
except Exception as e:
self.ap.logger.debug(f"上报失败: {e}")
self.ap.logger.debug(f'上报失败: {e}')
async def do(
self,
@@ -68,8 +66,8 @@ class APIGroup(metaclass=abc.ABCMeta):
return self.ap.task_mgr.create_task(
self._do(method, path, data, params, headers, **kwargs),
kind="telemetry-operation",
name=f"{method} {path}",
kind='telemetry-operation',
name=f'{method} {path}',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
).task
@@ -80,7 +78,7 @@ class APIGroup(metaclass=abc.ABCMeta):
def basic_info(self):
"""获取基本信息"""
basic_info = APIGroup._basic_info.copy()
basic_info["rid"] = self.gen_rid()
basic_info['rid'] = self.gen_rid()
return basic_info
def runtime_info(self):

View File

@@ -9,7 +9,7 @@ class V2MainDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/main", ap)
super().__init__(prefix + '/main', ap)
async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']:
@@ -25,31 +25,31 @@ class V2MainDataAPI(apigroup.APIGroup):
):
"""提交更新记录"""
return await self.do(
"POST",
"/update",
'POST',
'/update',
data={
"basic": self.basic_info(),
"update_info": {
"spent_seconds": spent_seconds,
"infer_reason": infer_reason,
"old_version": old_version,
"new_version": new_version,
}
}
'basic': self.basic_info(),
'update_info': {
'spent_seconds': spent_seconds,
'infer_reason': infer_reason,
'old_version': old_version,
'new_version': new_version,
},
},
)
async def post_announcement_showed(
self,
ids: list[int],
):
"""提交公告已阅"""
return await self.do(
"POST",
"/announcement",
'POST',
'/announcement',
data={
"basic": self.basic_info(),
"announcement_info": {
"ids": ids,
}
}
'basic': self.basic_info(),
'announcement_info': {
'ids': ids,
},
},
)

View File

@@ -9,39 +9,33 @@ class V2PluginDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/plugin", ap)
super().__init__(prefix + '/plugin', ap)
async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']:
return None
return await super().do(*args, **kwargs)
async def post_install_record(
self,
plugin: dict
):
async def post_install_record(self, plugin: dict):
"""提交插件安装记录"""
return await self.do(
"POST",
"/install",
'POST',
'/install',
data={
"basic": self.basic_info(),
"plugin": plugin,
}
'basic': self.basic_info(),
'plugin': plugin,
},
)
async def post_remove_record(
self,
plugin: dict
):
async def post_remove_record(self, plugin: dict):
"""提交插件卸载记录"""
return await self.do(
"POST",
"/remove",
'POST',
'/remove',
data={
"basic": self.basic_info(),
"plugin": plugin,
}
'basic': self.basic_info(),
'plugin': plugin,
},
)
async def post_update_record(
@@ -52,14 +46,14 @@ class V2PluginDataAPI(apigroup.APIGroup):
):
"""提交插件更新记录"""
return await self.do(
"POST",
"/update",
'POST',
'/update',
data={
"basic": self.basic_info(),
"plugin": plugin,
"update_info": {
"old_version": old_version,
"new_version": new_version,
}
}
'basic': self.basic_info(),
'plugin': plugin,
'update_info': {
'old_version': old_version,
'new_version': new_version,
},
},
)

View File

@@ -9,7 +9,7 @@ class V2UsageDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage", ap)
super().__init__(prefix + '/usage', ap)
async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']:
@@ -28,25 +28,25 @@ class V2UsageDataAPI(apigroup.APIGroup):
):
"""提交请求记录"""
return await self.do(
"POST",
"/query",
'POST',
'/query',
data={
"basic": self.basic_info(),
"runtime": self.runtime_info(),
"session_info": {
"type": session_type,
"id": session_id,
'basic': self.basic_info(),
'runtime': self.runtime_info(),
'session_info': {
'type': session_type,
'id': session_id,
},
"query_info": {
"ability_provider": query_ability_provider,
"usage": usage,
"model_name": model_name,
"response_seconds": response_seconds,
"retry_times": retry_times,
}
}
'query_info': {
'ability_provider': query_ability_provider,
'usage': usage,
'model_name': model_name,
'response_seconds': response_seconds,
'retry_times': retry_times,
},
},
)
async def post_event_record(
self,
plugins: list[dict],
@@ -54,18 +54,18 @@ class V2UsageDataAPI(apigroup.APIGroup):
):
"""提交事件触发记录"""
return await self.do(
"POST",
"/event",
'POST',
'/event',
data={
"basic": self.basic_info(),
"runtime": self.runtime_info(),
"plugins": plugins,
"event_info": {
"name": event_name,
}
}
'basic': self.basic_info(),
'runtime': self.runtime_info(),
'plugins': plugins,
'event_info': {
'name': event_name,
},
},
)
async def post_function_record(
self,
plugin: dict,
@@ -74,15 +74,14 @@ class V2UsageDataAPI(apigroup.APIGroup):
):
"""提交内容函数使用记录"""
return await self.do(
"POST",
"/function",
'POST',
'/function',
data={
"basic": self.basic_info(),
"plugin": plugin,
"function_info": {
"name": function_name,
"description": function_description,
}
}
'basic': self.basic_info(),
'plugin': plugin,
'function_info': {
'name': function_name,
'description': function_description,
},
},
)

View File

@@ -11,7 +11,7 @@ from ...core import app
class V2CenterAPI:
"""中央服务器 v2 API 交互类"""
main: main.V2MainDataAPI = None
"""主 API 组"""
@@ -21,15 +21,20 @@ class V2CenterAPI:
plugin: plugin.V2PluginDataAPI = None
"""插件 API 组"""
def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None):
def __init__(
self,
ap: app.Application,
backend_url: str,
basic_info: dict = None,
runtime_info: dict = None,
):
"""初始化"""
logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
logging.debug('basic_info: %s, runtime_info: %s', basic_info, runtime_info)
apigroup.APIGroup._basic_info = basic_info
apigroup.APIGroup._runtime_info = runtime_info
self.main = main.V2MainDataAPI(backend_url, ap)
self.usage = usage.V2UsageDataAPI(backend_url, ap)
self.plugin = plugin.V2PluginDataAPI(backend_url, ap)

View File

@@ -16,6 +16,7 @@ identifier = {
HOST_ID_FILE = os.path.expanduser('~/.langbot/host_id.json')
INSTANCE_ID_FILE = 'data/labels/instance_id.json'
def init():
global identifier
@@ -23,14 +24,11 @@ def init():
os.mkdir(os.path.expanduser('~/.langbot'))
if not os.path.exists(HOST_ID_FILE):
new_host_id = 'host_'+str(uuid.uuid4())
new_host_id = 'host_' + str(uuid.uuid4())
new_host_create_ts = int(time.time())
with open(HOST_ID_FILE, 'w') as f:
json.dump({
'host_id': new_host_id,
'host_create_ts': new_host_create_ts
}, f)
json.dump({'host_id': new_host_id, 'host_create_ts': new_host_create_ts}, f)
identifier['host_id'] = new_host_id
identifier['host_create_ts'] = new_host_create_ts
@@ -51,20 +49,25 @@ def init():
instance_id = {}
with open(INSTANCE_ID_FILE, 'r') as f:
instance_id = json.load(f)
if instance_id['host_id'] != identifier['host_id']: # 如果实例 id 不是当前主机的,删除
if (
instance_id['host_id'] != identifier['host_id']
): # 如果实例 id 不是当前主机的,删除
os.remove(INSTANCE_ID_FILE)
if not os.path.exists(INSTANCE_ID_FILE):
new_instance_id = 'instance_'+str(uuid.uuid4())
new_instance_id = 'instance_' + str(uuid.uuid4())
new_instance_create_ts = int(time.time())
with open(INSTANCE_ID_FILE, 'w') as f:
json.dump({
'host_id': identifier['host_id'],
'instance_id': new_instance_id,
'instance_create_ts': new_instance_create_ts
}, f)
json.dump(
{
'host_id': identifier['host_id'],
'instance_id': new_instance_id,
'instance_create_ts': new_instance_create_ts,
},
f,
)
identifier['instance_id'] = new_instance_id
identifier['instance_create_ts'] = new_instance_create_ts
@@ -80,6 +83,7 @@ def init():
identifier['instance_id'] = loaded_instance_id
identifier['instance_create_ts'] = loaded_instance_create_ts
def print_out():
global identifier
print(identifier)

View File

@@ -3,17 +3,17 @@ from __future__ import annotations
import typing
from ..core import app, entities as core_entities
from ..provider import entities as llm_entities
from . import entities, operator, errors
from ..config import manager as cfg_mgr
from ..utils import importutil
# 引入所有算子以便注册
from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
from . import operators
importutil.import_modules_in_pkg(operators)
class CommandManager:
"""命令管理器
"""
"""命令管理器"""
ap: app.Application
@@ -26,14 +26,13 @@ class CommandManager:
self.ap = ap
async def initialize(self):
# 设置各个类的路径
def set_path(cls: operator.CommandOperator, ancestors: list[str]):
cls.path = '.'.join(ancestors + [cls.name])
for op in operator.preregistered_operators:
if op.parent_class == cls:
set_path(op, ancestors + [cls.name])
for cls in operator.preregistered_operators:
if cls.parent_class is None:
set_path(cls, [])
@@ -41,14 +40,18 @@ class CommandManager:
# 应用命令权限配置
for cls in operator.preregistered_operators:
if cls.path in self.ap.instance_config.data['command']['privilege']:
cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path]
cls.lowest_privilege = self.ap.instance_config.data['command'][
'privilege'
][cls.path]
# 实例化所有类
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
# 设置所有类的子节点
for cmd in self.cmd_list:
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
cmd.children = [
child for child in self.cmd_list if child.parent_class == cmd.__class__
]
# 初始化所有类
for cmd in self.cmd_list:
@@ -58,27 +61,25 @@ class CommandManager:
self,
context: entities.ExecuteContext,
operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None
operator: operator.CommandOperator = None,
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
"""执行命令"""
found = False
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (context.crt_params[0] == oper.name \
or context.crt_params[0] in oper.alias) \
and (oper.parent_class is None or oper.parent_class == operator.__class__):
if (
context.crt_params[0] == oper.name
or context.crt_params[0] in oper.alias
) and (
oper.parent_class is None or oper.parent_class == operator.__class__
):
found = True
context.crt_command = context.crt_params[0]
context.crt_params = context.crt_params[1:]
async for ret in self._execute(
context,
oper.children,
oper
):
async for ret in self._execute(context, oper.children, oper):
yield ret
break
@@ -96,19 +97,20 @@ class CommandManager:
async for ret in operator.execute(context):
yield ret
async def execute(
self,
command_text: str,
query: core_entities.Query,
session: core_entities.Session
session: core_entities.Session,
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
"""执行命令"""
privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
privilege = 2
ctx = entities.ExecuteContext(
@@ -119,11 +121,8 @@ class CommandManager:
crt_command='',
params=command_text.split(' '),
crt_params=command_text.split(' '),
privilege=privilege
privilege=privilege,
)
async for ret in self._execute(
ctx,
self.cmd_list
):
async for ret in self._execute(ctx, self.cmd_list):
yield ret

View File

@@ -4,14 +4,13 @@ import typing
import pydantic.v1 as pydantic
from ..core import app, entities as core_entities
from . import errors, operator
from ..core import entities as core_entities
from . import errors
from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel):
"""命令返回值
"""
"""命令返回值"""
text: typing.Optional[str] = None
"""文本
@@ -24,7 +23,7 @@ class CommandReturn(pydantic.BaseModel):
"""图片链接
"""
error: typing.Optional[errors.CommandError]= None
error: typing.Optional[errors.CommandError] = None
"""错误
"""
@@ -33,8 +32,7 @@ class CommandReturn(pydantic.BaseModel):
class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文
"""
"""单次命令执行上下文"""
query: core_entities.Query
"""本次消息的请求对象"""

View File

@@ -1,33 +1,26 @@
class CommandError(Exception):
def __init__(self, message: str = None):
self.message = message
def __str__(self):
return self.message
class CommandNotFoundError(CommandError):
def __init__(self, message: str = None):
super().__init__("未知命令: "+message)
super().__init__('未知命令: ' + message)
class CommandPrivilegeError(CommandError):
def __init__(self, message: str = None):
super().__init__("权限不足: "+message)
super().__init__('权限不足: ' + message)
class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None):
super().__init__("参数不足: "+message)
super().__init__('参数不足: ' + message)
class CommandOperationError(CommandError):
def __init__(self, message: str = None):
super().__init__("操作失败: "+message)
super().__init__('操作失败: ' + message)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing
import abc
from ..core import app, entities as core_entities
from ..core import app
from . import entities
@@ -13,14 +13,14 @@ preregistered_operators: list[typing.Type[CommandOperator]] = []
def operator_class(
name: str,
help: str = "",
help: str = '',
usage: str = None,
alias: list[str] = [],
privilege: int=1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None
privilege: int = 1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None,
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
"""命令类装饰器
Args:
name (str): 名称
help (str, optional): 帮助信息. Defaults to "".
@@ -35,7 +35,7 @@ def operator_class(
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
assert issubclass(cls, CommandOperator)
cls.name = name
cls.alias = alias
cls.help = help
@@ -96,14 +96,13 @@ class CommandOperator(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令
支持多次yield以返回多个结果。
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
Args:
context (entities.ExecuteContext): 命令执行上下文

View File

@@ -2,49 +2,46 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="cmd",
help='显示命令列表',
usage='!cmd\n!cmd <命令名称>'
)
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
class CmdOperator(operator.CommandOperator):
"""命令列表
"""
"""命令列表"""
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
"""执行"""
if len(context.crt_params) == 0:
reply_str = "当前所有命令: \n\n"
reply_str = '当前所有命令: \n\n'
for cmd in self.ap.cmd_mgr.cmd_list:
if cmd.parent_class is None:
reply_str += f"{cmd.name}: {cmd.help}\n"
reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助"
reply_str += f'{cmd.name}: {cmd.help}\n'
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
yield entities.CommandReturn(text=reply_str.strip())
else:
cmd_name = context.crt_params[0]
cmd = None
for _cmd in self.ap.cmd_mgr.cmd_list:
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (
_cmd.parent_class is None
):
cmd = _cmd
break
if cmd is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
yield entities.CommandReturn(
error=errors.CommandNotFoundError(cmd_name)
)
else:
reply_str = f"{cmd.name}: {cmd.help}\n\n"
reply_str += f"使用方法: \n{cmd.usage}"
reply_str = f'{cmd.name}: {cmd.help}\n\n'
reply_str += f'使用方法: \n{cmd.usage}'
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -1,62 +1,60 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="del",
help="删除当前会话的历史记录",
usage='!del <序号>\n!del all'
name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all'
)
class DelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
delete_index = 0
if len(context.crt_params) > 0:
try:
delete_index = int(context.crt_params[0])
except:
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
except Exception:
yield entities.CommandReturn(
error=errors.CommandOperationError('索引必须是整数')
)
return
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
return
# 倒序
to_delete_index = len(context.session.conversations)-1-delete_index
if context.session.conversations[to_delete_index] == context.session.using_conversation:
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(
error=errors.CommandOperationError('索引超出范围')
)
return
# 倒序
to_delete_index = len(context.session.conversations) - 1 - delete_index
if (
context.session.conversations[to_delete_index]
== context.session.using_conversation
):
context.session.using_conversation = None
del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f"已删除对话: {delete_index}")
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
@operator.operator_class(
name="all",
help="删除此会话的所有历史记录",
parent_class=DelOperator
name='all', help='删除此会话的所有历史记录', parent_class=DelOperator
)
class DelAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = []
context.session.using_conversation = None
yield entities.CommandReturn(text="已删除所有对话")
yield entities.CommandReturn(text='已删除所有对话')

View File

@@ -1,16 +1,15 @@
from __future__ import annotations
from typing import AsyncGenerator
from .. import operator, entities, cmdmgr
from ...plugin import context as plugin_context
from .. import operator, entities
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
class FuncOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已启用的内容函数: \n\n"
reply_str = '当前已启用的内容函数: \n\n'
index = 1
@@ -19,7 +18,7 @@ class FuncOperator(operator.CommandOperator):
)
for func in all_functions:
reply_str += "{}. {}:\n{}\n\n".format(
reply_str += '{}. {}:\n{}\n\n'.format(
index,
func.name,
func.description,

View File

@@ -2,19 +2,13 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities
@operator.operator_class(
name='help',
help='显示帮助',
usage='!help\n!help <命令名称>'
)
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
class HelpOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app'

View File

@@ -1,36 +1,43 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="last",
help="切换到前一个对话",
usage='!last'
)
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
class LastOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的上一个会话
for index in range(len(context.session.conversations)-1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation:
for index in range(len(context.session.conversations) - 1, -1, -1):
if (
context.session.conversations[index]
== context.session.using_conversation
):
if index == 0:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是第一个对话了')
)
return
else:
context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
context.session.using_conversation = (
context.session.conversations[index - 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
yield entities.CommandReturn(
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
)
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)

View File

@@ -1,30 +1,26 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="list",
help="列出此会话中的所有历史对话",
usage='!list\n!list <页码>'
name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>'
)
class ListOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0
if len(context.crt_params) > 0:
try:
page = int(context.crt_params[0]-1)
except:
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
page = int(context.crt_params[0] - 1)
except Exception:
yield entities.CommandReturn(
error=errors.CommandOperationError('页码应为整数')
)
return
record_per_page = 10
@@ -36,21 +32,21 @@ class ListOperator(operator.CommandOperator):
using_conv_index = 0
for conv in context.session.conversations[::-1]:
time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S")
time_str = conv.create_time.strftime('%Y-%m-%d %H:%M:%S')
if conv == context.session.using_conversation:
using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
index += 1
if content == '':
content = ''
else:
if context.session.using_conversation is None:
content += "\n当前处于新会话"
content += '\n当前处于新会话'
else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"
yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
yield entities.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}')

View File

@@ -2,42 +2,44 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="model",
name='model',
help='显示和切换模型列表',
usage='!model\n!model show <模型名>\n!model set <模型名>',
privilege=2
privilege=2,
)
class ModelOperator(operator.CommandOperator):
"""Model命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content = '模型列表:\n'
model_list = self.ap.model_mgr.model_list
for model in model_list:
content += f"\n名称: {model.name}\n"
content += f"请求器: {model.requester.name}\n"
content += f'\n名称: {model.name}\n'
content += f'请求器: {model.requester.name}\n'
content += f"\n当前对话使用模型: {context.query.use_model.name}\n"
content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n"
content += f'\n当前对话使用模型: {context.query.use_model.name}\n'
content += f'新对话默认使用模型: {self.ap.provider_cfg.data.get("model")}\n'
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name="show",
help='显示模型详情',
privilege=2,
parent_class=ModelOperator
name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator
)
class ModelShowOperator(operator.CommandOperator):
"""Model Show命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
@@ -47,29 +49,31 @@ class ModelShowOperator(operator.CommandOperator):
break
if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
else:
content = f"模型详情\n"
content += f"名称: {model.name}\n"
content = '模型详情\n'
content += f'名称: {model.name}\n'
if model.model_name is not None:
content += f"请求模型名称: {model.model_name}\n"
content += f"请求器: {model.requester.name}\n"
content += f"密钥组: {model.token_mgr.name}\n"
content += f"支持视觉: {model.vision_supported}\n"
content += f"支持工具: {model.tool_call_supported}\n"
content += f'请求模型名称: {model.model_name}\n'
content += f'请求器: {model.requester.name}\n'
content += f'密钥组: {model.token_mgr.name}\n'
content += f'支持视觉: {model.vision_supported}\n'
content += f'支持工具: {model.tool_call_supported}\n'
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name="set",
help='设置默认使用模型',
privilege=2,
parent_class=ModelOperator
name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator
)
class ModelSetOperator(operator.CommandOperator):
"""Model Set命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
@@ -79,8 +83,12 @@ class ModelSetOperator(operator.CommandOperator):
break
if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
else:
self.ap.provider_cfg.data['model'] = model_name
await self.ap.provider_cfg.dump_config()
yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效")
yield entities.CommandReturn(
text=f'已设置当前使用模型为 {model_name},重置会话以生效'
)

View File

@@ -1,35 +1,42 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="next",
help="切换到后一个对话",
usage='!next'
)
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
class NextOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的下一个会话
for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations)-1:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
if (
context.session.conversations[index]
== context.session.using_conversation
):
if index == len(context.session.conversations) - 1:
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是最后一个对话了')
)
return
else:
context.session.using_conversation = context.session.conversations[index+1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
context.session.using_conversation = (
context.session.conversations[index + 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
yield entities.CommandReturn(
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
)
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)

View File

@@ -2,31 +2,32 @@ from __future__ import annotations
import json
import typing
import traceback
import ollama
from .. import operator, entities, errors
@operator.operator_class(
name="ollama",
help="ollama平台操作",
usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>"
name='ollama',
help='ollama平台操作',
usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>',
)
class OllamaOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
content: str = '模型列表:\n'
model_list: list = ollama.list().get('models', [])
for model in model_list:
content += f"名称: {model['name']}\n"
content += f"修改时间: {model['modified_at']}\n"
content += f"大小: {bytes_to_mb(model['size'])}MB\n\n"
yield entities.CommandReturn(text=f"{content.strip()}")
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
content += f'名称: {model["name"]}\n'
content += f'修改时间: {model["modified_at"]}\n'
content += f'大小: {bytes_to_mb(model["size"])}MB\n\n'
yield entities.CommandReturn(text=f'{content.strip()}')
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
def bytes_to_mb(num_bytes):
@@ -35,14 +36,11 @@ def bytes_to_mb(num_bytes):
@operator.operator_class(
name="show",
help="ollama模型详情",
privilege=2,
parent_class=OllamaOperator
name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator
)
class OllamaShowOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content: str = '模型详情:\n'
try:
@@ -53,31 +51,36 @@ class OllamaShowOperator(operator.CommandOperator):
for key in ['license', 'modelfile']:
show[key] = ignore_show
for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']:
for key in [
'tokenizer.chat_template.rag',
'tokenizer.chat_template.tool_use',
]:
model_info[key] = ignore_show
content += json.dumps(show, indent=4)
yield entities.CommandReturn(text=content.strip())
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常"))
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常')
)
@operator.operator_class(
name="pull",
help="ollama模型拉取",
privilege=2,
parent_class=OllamaOperator
name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator
)
class OllamaPullOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
model_list: list = ollama.list().get('models', [])
if context.crt_params[0] in [model['name'] for model in model_list]:
yield entities.CommandReturn(text="模型已存在")
yield entities.CommandReturn(text='模型已存在')
return
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
return
on_progress: bool = False
@@ -99,23 +102,21 @@ class OllamaPullOperator(operator.CommandOperator):
if percentage_completed > progress_count:
progress_count += 10
yield entities.CommandReturn(
text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)")
text=f'下载进度: {completed}/{total} ({percentage_completed:.2f}%)'
)
except ollama.ResponseError as e:
yield entities.CommandReturn(text=f"拉取失败: {e.error}")
yield entities.CommandReturn(text=f'拉取失败: {e.error}')
@operator.operator_class(
name="del",
help="ollama模型删除",
privilege=2,
parent_class=OllamaOperator
name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator
)
class OllamaDelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
ret: str = ollama.delete(model=context.crt_params[0])['status']
except ollama.ResponseError as e:
ret = f"{e.error}"
ret = f'{e.error}'
yield entities.CommandReturn(text=ret)

View File

@@ -2,31 +2,30 @@ from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from ...core import app
from .. import operator, entities, errors
@operator.operator_class(
name="plugin",
help="插件操作",
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>"
name='plugin',
help='插件操作',
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
)
class PluginOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins()
reply_str = "所有插件({}):\n".format(len(plugin_list))
reply_str = '所有插件({}):\n'.format(len(plugin_list))
idx = 0
for plugin in plugin_list:
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin.plugin_name,
"[已禁用]" if not plugin.enabled else "",
plugin.plugin_description,
plugin.plugin_version, plugin.plugin_author)
reply_str += '\n#{} {} {}\n{}\nv{}\n作者: {}\n'.format(
(idx + 1),
plugin.plugin_name,
'[已禁用]' if not plugin.enabled else '',
plugin.plugin_description,
plugin.plugin_version,
plugin.plugin_author,
)
idx += 1
@@ -34,48 +33,42 @@ class PluginOperator(operator.CommandOperator):
@operator.operator_class(
name="get",
help="安装插件",
privilege=2,
parent_class=PluginOperator
name='get', help='安装插件', privilege=2, parent_class=PluginOperator
)
class PluginGetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件仓库地址')
)
else:
repo = context.crt_params[0]
yield entities.CommandReturn(text="正在安装插件...")
yield entities.CommandReturn(text='正在安装插件...')
try:
await self.ap.plugin_mgr.install_plugin(repo)
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件安装失败: ' + str(e))
)
@operator.operator_class(
name="update",
help="更新插件",
privilege=2,
parent_class=PluginOperator
name='update', help='更新插件', privilege=2, parent_class=PluginOperator
)
class PluginUpdateOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else:
plugin_name = context.crt_params[0]
@@ -83,36 +76,34 @@ class PluginUpdateOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None:
yield entities.CommandReturn(text="正在更新插件...")
yield entities.CommandReturn(text='正在更新插件...')
await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
yield entities.CommandReturn(
text='插件更新成功,请重启程序以加载插件'
)
else:
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: 未找到插件')
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
@operator.operator_class(
name="all",
help="更新所有插件",
privilege=2,
parent_class=PluginUpdateOperator
name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator
)
class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = [
p.plugin_name
for p in self.ap.plugin_mgr.plugins()
]
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
if plugins:
yield entities.CommandReturn(text="正在更新插件...")
yield entities.CommandReturn(text='正在更新插件...')
updated = []
try:
for plugin_name in plugins:
@@ -120,30 +111,32 @@ class PluginUpdateAllOperator(operator.CommandOperator):
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated)))
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(
text='已更新插件: {}'.format(', '.join(updated))
)
else:
yield entities.CommandReturn(text="没有可更新的插件")
yield entities.CommandReturn(text='没有可更新的插件')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
@operator.operator_class(
name="del",
help="删除插件",
privilege=2,
parent_class=PluginOperator
name='del', help='删除插件', privilege=2, parent_class=PluginOperator
)
class PluginDelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else:
plugin_name = context.crt_params[0]
@@ -151,67 +144,81 @@ class PluginDelOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None:
yield entities.CommandReturn(text="正在删除插件...")
yield entities.CommandReturn(text='正在删除插件...')
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
yield entities.CommandReturn(
text='插件删除成功,请重启程序以加载插件'
)
else:
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: 未找到插件')
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: ' + str(e))
)
@operator.operator_class(
name="on",
help="启用插件",
privilege=2,
parent_class=PluginOperator
name='on', help='启用插件', privilege=2, parent_class=PluginOperator
)
class PluginEnableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
yield entities.CommandReturn(
text='已启用插件: {}'.format(plugin_name)
)
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)
@operator.operator_class(
name="off",
help="禁用插件",
privilege=2,
parent_class=PluginOperator
name='off', help='禁用插件', privilege=2, parent_class=PluginOperator
)
class PluginDisableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
yield entities.CommandReturn(
text='已禁用插件: {}'.format(plugin_name)
)
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)

View File

@@ -2,28 +2,23 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="prompt",
help="查看当前对话的前文",
usage='!prompt'
)
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
class PromptOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
"""执行"""
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
else:
reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages:
reply_str += f"{msg.role}: {msg.content}\n"
reply_str += f'{msg.role}: {msg.content}\n'
yield entities.CommandReturn(text=reply_str)
yield entities.CommandReturn(text=reply_str)

View File

@@ -2,26 +2,22 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="resend",
help="重发当前会话的最后一条消息",
usage='!resend'
name='resend', help='重发当前会话的最后一条消息', usage='!resend'
)
class ResendOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError("当前没有对话"))
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
else:
conv_msg = context.session.using_conversation.messages
# 倒序一直删到最后一条用户message
while len(conv_msg) > 0 and conv_msg[-1].role != 'user':
conv_msg.pop()
@@ -31,4 +27,4 @@ class ResendOperator(operator.CommandOperator):
conv_msg.pop()
# 不重发了,提示用户已删除就行了
yield entities.CommandReturn(text="已删除最后一次请求记录")
yield entities.CommandReturn(text='已删除最后一次请求记录')

View File

@@ -2,22 +2,15 @@ from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities
@operator.operator_class(
name="reset",
help="重置当前会话",
usage='!reset'
)
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
class ResetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
"""执行"""
context.session.using_conversation = None
yield entities.CommandReturn(text="已重置当前会话")
yield entities.CommandReturn(text='已重置当前会话')

View File

@@ -3,28 +3,22 @@ from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from .. import operator, entities, errors
@operator.operator_class(
name="update",
help="更新程序",
usage='!update',
privilege=2
)
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
class UpdateCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
yield entities.CommandReturn(text="正在进行更新...")
yield entities.CommandReturn(text='正在进行更新...')
if await self.ap.ver_mgr.update_all():
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
yield entities.CommandReturn(text='更新完成,请重启程序以应用更新')
else:
yield entities.CommandReturn(text="当前已是最新版本")
yield entities.CommandReturn(text='当前已是最新版本')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e)))
yield entities.CommandReturn(
error=errors.CommandError('更新失败: ' + str(e))
)

View File

@@ -2,26 +2,20 @@ from __future__ import annotations
import typing
from .. import operator, cmdmgr, entities, errors
from .. import operator, entities
@operator.operator_class(
name="version",
help="显示版本信息",
usage='!version'
)
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
class VersionCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f"当前版本: \n{self.ap.ver_mgr.get_current_version()}"
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
try:
if await self.ap.ver_mgr.is_new_version_available():
reply_str += "\n\n有新版本可用。"
except:
reply_str += '\n\n有新版本可用。'
except Exception:
pass
yield entities.CommandReturn(text=reply_str.strip())
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -9,7 +9,10 @@ class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件"""
def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
self,
config_file_name: str,
template_file_name: str = None,
template_data: dict = None,
) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
@@ -22,28 +25,26 @@ class JSONConfigFile(file_model.ConfigFile):
if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(self.template_data, f, indent=4, ensure_ascii=False)
else:
raise ValueError("template_file_name or template_data must be provided")
async def load(self, completion: bool=True) -> dict:
raise ValueError('template_file_name or template_data must be provided')
async def load(self, completion: bool = True) -> dict:
if not self.exists():
await self.create()
if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f:
with open(self.template_file_name, 'r', encoding='utf-8') as f:
self.template_data = json.load(f)
with open(self.config_file_name, "r", encoding="utf-8") as f:
with open(self.config_file_name, 'r', encoding='utf-8') as f:
try:
cfg = json.load(f)
except json.JSONDecodeError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}')
if completion:
for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
@@ -51,9 +52,9 @@ class JSONConfigFile(file_model.ConfigFile):
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)
def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)

View File

@@ -25,10 +25,10 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self, completion: bool=True) -> dict:
async def load(self, completion: bool = True) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name)
cfg = {}
allowed_types = (int, float, str, bool, list, dict)
@@ -63,4 +63,4 @@ class PythonModuleConfigFile(file_model.ConfigFile):
logging.warning('Python模块配置文件不支持保存')
def save_sync(self, data: dict):
logging.warning('Python模块配置文件不支持保存')
logging.warning('Python模块配置文件不支持保存')

View File

@@ -9,7 +9,10 @@ class YAMLConfigFile(file_model.ConfigFile):
"""YAML配置文件"""
def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
self,
config_file_name: str,
template_file_name: str = None,
template_data: dict = None,
) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
@@ -22,28 +25,26 @@ class YAMLConfigFile(file_model.ConfigFile):
if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(self.template_data, f, indent=4, allow_unicode=True)
else:
raise ValueError("template_file_name or template_data must be provided")
async def load(self, completion: bool=True) -> dict:
raise ValueError('template_file_name or template_data must be provided')
async def load(self, completion: bool = True) -> dict:
if not self.exists():
await self.create()
if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f:
with open(self.template_file_name, 'r', encoding='utf-8') as f:
self.template_data = yaml.load(f, Loader=yaml.FullLoader)
with open(self.config_file_name, "r", encoding="utf-8") as f:
with open(self.config_file_name, 'r', encoding='utf-8') as f:
try:
cfg = yaml.load(f, Loader=yaml.FullLoader)
except yaml.YAMLError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}')
if completion:
for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
@@ -51,9 +52,9 @@ class YAMLConfigFile(file_model.ConfigFile):
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True)
def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True)
with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True)

View File

@@ -6,7 +6,7 @@ from .impls import pymodule, json as json_file, yaml as yaml_file
class ConfigManager:
"""配置文件管理器"""
name: str = None
"""配置管理器名"""
@@ -31,7 +31,7 @@ class ConfigManager:
self.file = cfg_file
self.data = {}
async def load_config(self, completion: bool=True):
async def load_config(self, completion: bool = True):
self.data = await self.file.load(completion=completion)
async def dump_config(self):
@@ -41,9 +41,11 @@ class ConfigManager:
self.file.save_sync(self.data)
async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
async def load_python_module_config(
config_name: str, template_name: str, completion: bool = True
) -> ConfigManager:
"""加载Python模块配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
@@ -52,10 +54,7 @@ async def load_python_module_config(config_name: str, template_name: str, comple
Returns:
ConfigManager: 配置文件管理器
"""
cfg_inst = pymodule.PythonModuleConfigFile(
config_name,
template_name
)
cfg_inst = pymodule.PythonModuleConfigFile(config_name, template_name)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion)
@@ -63,20 +62,21 @@ async def load_python_module_config(config_name: str, template_name: str, comple
return cfg_mgr
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
async def load_json_config(
config_name: str,
template_name: str = None,
template_data: dict = None,
completion: bool = True,
) -> ConfigManager:
"""加载JSON配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
template_data (dict): 模板数据
completion (bool): 是否自动补全内存中的配置文件
"""
cfg_inst = json_file.JSONConfigFile(
config_name,
template_name,
template_data
)
cfg_inst = json_file.JSONConfigFile(config_name, template_name, template_data)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion)
@@ -84,9 +84,14 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
return cfg_mgr
async def load_yaml_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
async def load_yaml_config(
config_name: str,
template_name: str = None,
template_data: dict = None,
completion: bool = True,
) -> ConfigManager:
"""加载YAML配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
@@ -96,11 +101,7 @@ async def load_yaml_config(config_name: str, template_name: str=None, template_d
Returns:
ConfigManager: 配置文件管理器
"""
cfg_inst = yaml_file.YAMLConfigFile(
config_name,
template_name,
template_data
)
cfg_inst = yaml_file.YAMLConfigFile(config_name, template_name, template_data)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion)

View File

@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def load(self, completion: bool=True) -> dict:
async def load(self, completion: bool = True) -> dict:
pass
@abc.abstractmethod

View File

@@ -2,9 +2,7 @@ from __future__ import annotations
import logging
import asyncio
import threading
import traceback
import enum
import sys
import os
@@ -29,7 +27,6 @@ from ..discover import engine as discover_engine
from ..utils import logcache, ip
from . import taskmgr
from . import entities as core_entities
from .bootutils import config
class Application:
@@ -123,33 +120,55 @@ class Application:
async def run(self):
try:
await self.plugin_mgr.initialize_plugins()
# 后续可能会允许动态重启其他任务
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
async def never_ending():
while True:
await asyncio.sleep(1)
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
self.task_mgr.create_task(
self.ctrl.run(),
name='query-controller',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
self.task_mgr.create_task(
self.http_ctrl.run(),
name='http-api-controller',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
self.task_mgr.create_task(
never_ending(),
name='never-ending-task',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
await self.print_web_access_info()
await self.task_mgr.wait_all()
except asyncio.CancelledError:
pass
except Exception as e:
self.logger.error(f"应用运行致命异常: {e}")
self.logger.debug(f"Traceback: {traceback.format_exc()}")
self.logger.error(f'应用运行致命异常: {e}')
self.logger.debug(f'Traceback: {traceback.format_exc()}')
async def print_web_access_info(self):
"""打印访问 webui 的提示"""
if not os.path.exists(os.path.join(".", "web/out")):
self.logger.warning("WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html")
if not os.path.exists(os.path.join('.', 'web/out')):
self.logger.warning(
'WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html'
)
return
host_ip = "127.0.0.1"
host_ip = '127.0.0.1'
public_ip = await ip.get_myip()
@@ -170,7 +189,7 @@ class Application:
🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
=======================================
""".strip()
for line in tips.split("\n"):
for line in tips.split('\n'):
self.logger.info(line)
async def reload(
@@ -179,21 +198,28 @@ class Application:
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info("执行热重载 scope="+scope)
self.logger.info('执行热重载 scope=' + scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info("执行热重载 scope="+scope)
self.logger.info('执行热重载 scope=' + scope)
await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()):
if mod.startswith("plugins."):
if mod.startswith('plugins.'):
del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self)
@@ -204,7 +230,7 @@ class Application:
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info("执行热重载 scope="+scope)
self.logger.info('执行热重载 scope=' + scope)
await self.tool_mgr.shutdown()
@@ -220,4 +246,4 @@ class Application:
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst
case _:
pass
pass

View File

@@ -7,29 +7,30 @@ import os
from . import app
from ..audit import identifier
from . import stage
from ..utils import constants
from ..utils import constants, importutil
# 引入启动阶段实现以便注册
from .stages import load_config, setup_logger, build_app, migrate, show_notes, genkeys
from . import stages
importutil.import_modules_in_pkg(stages)
stage_order = [
"LoadConfigStage",
"MigrationStage",
"GenKeysStage",
"SetupLoggerStage",
"BuildAppStage",
"ShowNotesStage"
'LoadConfigStage',
'MigrationStage',
'GenKeysStage',
'SetupLoggerStage',
'BuildAppStage',
'ShowNotesStage',
]
async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
# 生成标识符
identifier.init()
# 确定是否为调试模式
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']:
constants.debug_mode = True
ap = app.Application()
@@ -50,21 +51,17 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
async def main(loop: asyncio.AbstractEventLoop):
try:
# 挂系统信号处理
import signal
ap: app.Application
def signal_handler(sig, frame):
print("[Signal] 程序退出.")
print('[Signal] 程序退出.')
# ap.shutdown()
os._exit(0)
signal.signal(signal.SIGINT, signal_handler)
app_inst = await make_app(loop)
ap = app_inst
await app_inst.run()
except Exception as e:
except Exception:
traceback.print_exc()

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
import json
from ...config import manager as config_mgr
from ...config.impls import pymodule
load_python_module_config = config_mgr.load_python_module_config
load_json_config = config_mgr.load_json_config
load_yaml_config = config_mgr.load_yaml_config
load_yaml_config = config_mgr.load_yaml_config

View File

@@ -5,39 +5,39 @@ from ...utils import pkgmgr
# 检查依赖,防止用户未安装
# 左边为引入名称,右边为依赖名称
required_deps = {
"requests": "requests",
"openai": "openai",
"anthropic": "anthropic",
"colorlog": "colorlog",
"aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy-rc",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"tiktoken": "tiktoken",
"yaml": "pyyaml",
"aiohttp": "aiohttp",
"psutil": "psutil",
"async_lru": "async-lru",
"ollama": "ollama",
"quart": "quart",
"quart_cors": "quart-cors",
"sqlalchemy": "sqlalchemy[asyncio]",
"aiosqlite": "aiosqlite",
"aiofiles": "aiofiles",
"aioshutil": "aioshutil",
"argon2": "argon2-cffi",
"jwt": "pyjwt",
"Crypto": "pycryptodome",
"lark_oapi": "lark-oapi",
"discord": "discord.py",
"cryptography": "cryptography",
"gewechat_client": "gewechat-client",
"dingtalk_stream": "dingtalk_stream",
"dashscope": "dashscope",
"telegram": "python-telegram-bot",
"certifi": "certifi",
"mcp": "mcp",
"sqlmodel": "sqlmodel",
'requests': 'requests',
'openai': 'openai',
'anthropic': 'anthropic',
'colorlog': 'colorlog',
'aiocqhttp': 'aiocqhttp',
'botpy': 'qq-botpy-rc',
'PIL': 'pillow',
'nakuru': 'nakuru-project-idk',
'tiktoken': 'tiktoken',
'yaml': 'pyyaml',
'aiohttp': 'aiohttp',
'psutil': 'psutil',
'async_lru': 'async-lru',
'ollama': 'ollama',
'quart': 'quart',
'quart_cors': 'quart-cors',
'sqlalchemy': 'sqlalchemy[asyncio]',
'aiosqlite': 'aiosqlite',
'aiofiles': 'aiofiles',
'aioshutil': 'aioshutil',
'argon2': 'argon2-cffi',
'jwt': 'pyjwt',
'Crypto': 'pycryptodome',
'lark_oapi': 'lark-oapi',
'discord': 'discord.py',
'cryptography': 'cryptography',
'gewechat_client': 'gewechat-client',
'dingtalk_stream': 'dingtalk_stream',
'dashscope': 'dashscope',
'telegram': 'python-telegram-bot',
'certifi': 'certifi',
'mcp': 'mcp',
'sqlmodel': 'sqlmodel',
}
@@ -52,20 +52,25 @@ async def check_deps() -> list[str]:
missing_deps.append(dep)
return missing_deps
async def install_deps(deps: list[str]):
global required_deps
for dep in deps:
pip.main(["install", required_deps[dep]])
pip.main(['install', required_deps[dep]])
async def precheck_plugin_deps():
print('[Startup] Prechecking plugin dependencies...')
# 只有在plugins目录存在时才执行插件依赖安装
if os.path.exists("plugins"):
for dir in os.listdir("plugins"):
subdir = os.path.join("plugins", dir)
if os.path.exists('plugins'):
for dir in os.listdir('plugins'):
subdir = os.path.join('plugins', dir)
if not os.path.isdir(subdir):
continue
if 'requirements.txt' in os.listdir(subdir):
pkgmgr.install_requirements(os.path.join(subdir, 'requirements.txt'), extra_params=['-q', '-q', '-q'])
pkgmgr.install_requirements(
os.path.join(subdir, 'requirements.txt'),
extra_params=['-q', '-q', '-q'],
)

View File

@@ -2,23 +2,23 @@ from __future__ import annotations
import os
import shutil
import sys
required_files = {
"plugins/__init__.py": "templates/__init__.py",
"data/config.yaml": "templates/config.yaml",
'plugins/__init__.py': 'templates/__init__.py',
'data/config.yaml': 'templates/config.yaml',
}
required_paths = [
"temp",
"data",
"data/metadata",
"data/logs",
"data/labels",
"plugins"
'temp',
'data',
'data/metadata',
'data/logs',
'data/labels',
'plugins',
]
async def generate_files() -> list[str]:
global required_files, required_paths

View File

@@ -1,5 +1,4 @@
import logging
import os
import sys
import time
@@ -9,11 +8,11 @@ from ...utils import constants
log_colors_config = {
"DEBUG": "green", # cyan white
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "cyan",
'DEBUG': 'green', # cyan white
'INFO': 'white',
'WARNING': 'yellow',
'ERROR': 'red',
'CRITICAL': 'cyan',
}
@@ -27,26 +26,31 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
if constants.debug_mode:
level = logging.DEBUG
log_file_name = "data/logs/langbot-%s.log" % time.strftime(
"%Y-%m-%d", time.localtime()
log_file_name = 'data/logs/langbot-%s.log' % time.strftime(
'%Y-%m-%d', time.localtime()
)
qcg_logger = logging.getLogger("langbot")
qcg_logger = logging.getLogger('langbot')
qcg_logger.setLevel(level)
color_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s",
datefmt="%m-%d %H:%M:%S",
fmt='%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s',
datefmt='%m-%d %H:%M:%S',
log_colors=log_colors_config,
)
stream_handler = logging.StreamHandler(sys.stdout)
# stream_handler.setLevel(level)
# stream_handler.setFormatter(color_formatter)
stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
stream_handler.stream = open(
sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1
)
log_handlers: list[logging.Handler] = [stream_handler, logging.FileHandler(log_file_name, encoding='utf-8')]
log_handlers: list[logging.Handler] = [
stream_handler,
logging.FileHandler(log_file_name, encoding='utf-8'),
]
log_handlers += extra_handlers if extra_handlers is not None else []
for handler in log_handlers:
@@ -54,13 +58,13 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
handler.setFormatter(color_formatter)
qcg_logger.addHandler(handler)
qcg_logger.debug("日志初始化完成,日志级别:%s" % level)
qcg_logger.debug('日志初始化完成,日志级别:%s' % level)
logging.basicConfig(
level=logging.CRITICAL, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
format='[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s',
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
datefmt='%Y-%m-%d %H:%M:%S', # 时间输出的格式
handlers=[logging.NullHandler()],
)

View File

@@ -8,21 +8,18 @@ import asyncio
import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import entities, modelmgr, requester
from ..provider.modelmgr import requester
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities
class LifecycleControlScope(enum.Enum):
APPLICATION = "application"
PLATFORM = "platform"
PLUGIN = "plugin"
PROVIDER = "provider"
APPLICATION = 'application'
PLATFORM = 'platform'
PLUGIN = 'plugin'
PROVIDER = 'provider'
class LauncherTypes(enum.Enum):
@@ -89,14 +86,17 @@ class Query(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = []
resp_messages: (
typing.Optional[list[llm_entities.Message]]
| typing.Optional[list[platform_message.MessageChain]]
) = []
"""由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======
current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None
current_stage = None # pkg.pipeline.pipelinemgr.StageInstContainer
"""当前所处阶段"""
class Config:
@@ -109,13 +109,13 @@ class Query(pydantic.BaseModel):
if self.variables is None:
self.variables = {}
self.variables[key] = value
def get_variable(self, key: str) -> typing.Any:
"""获取变量"""
if self.variables is None:
return None
return self.variables.get(key)
def get_variables(self) -> dict[str, typing.Any]:
"""获取所有变量"""
if self.variables is None:
@@ -130,9 +130,13 @@ class Conversation(pydantic.BaseModel):
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
use_llm_model: requester.RuntimeLLMModel
@@ -147,6 +151,7 @@ class Conversation(pydantic.BaseModel):
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes
launcher_id: typing.Union[int, str]
@@ -157,11 +162,17 @@ class Session(pydantic.BaseModel):
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
conversations: typing.Optional[list[Conversation]] = pydantic.Field(
default_factory=list
)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""

View File

@@ -9,21 +9,21 @@ from . import app
preregistered_migrations: list[typing.Type[Migration]] = []
"""当前阶段暂不支持扩展"""
def migration_class(name: str, number: int):
"""注册一个迁移
"""
"""注册一个迁移"""
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
cls.name = name
cls.number = number
preregistered_migrations.append(cls)
return cls
return decorator
class Migration(abc.ABC):
"""一个版本的迁移
"""
"""一个版本的迁移"""
name: str
@@ -33,15 +33,13 @@ class Migration(abc.ABC):
def __init__(self, ap: app.Application):
self.ap = ap
@abc.abstractmethod
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
"""判断当前环境是否需要运行此迁移"""
pass
@abc.abstractmethod
async def run(self):
"""执行迁移
"""
"""执行迁移"""
pass

View File

@@ -1,26 +1,26 @@
from __future__ import annotations
import os
import sys
from .. import migration
@migration.migration_class("sensitive-word-migration", 1)
@migration.migration_class('sensitive-word-migration', 1)
class SensitiveWordMigration(migration.Migration):
"""敏感词迁移
"""
"""敏感词迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json")
"""判断当前环境是否需要运行此迁移"""
return os.path.exists(
'data/config/sensitive-words.json'
) and not os.path.exists('data/metadata/sensitive-words.json')
async def run(self):
"""执行迁移
"""
"""执行迁移"""
# 移动文件
os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json")
os.rename(
'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json'
)
# 重新加载配置
await self.ap.sensitive_meta.load_config()

View File

@@ -3,19 +3,16 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("openai-config-migration", 2)
@migration.migration_class('openai-config-migration', 2)
class OpenAIConfigMigration(migration.Migration):
"""OpenAI配置迁移
"""
"""OpenAI配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
"""判断当前环境是否需要运行此迁移"""
return 'openai-config' in self.ap.provider_cfg.data
async def run(self):
"""执行迁移
"""
"""执行迁移"""
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
if 'keys' not in self.ap.provider_cfg.data:
@@ -26,7 +23,9 @@ class OpenAIConfigMigration(migration.Migration):
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
self.ap.provider_cfg.data['model'] = old_openai_config[
'chat-completions-params'
]['model']
del old_openai_config['chat-completions-params']['model']
@@ -35,7 +34,7 @@ class OpenAIConfigMigration(migration.Migration):
if 'openai-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {}
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {
'base-url': old_openai_config['base_url'],
'args': old_openai_config['chat-completions-params'],
@@ -44,4 +43,4 @@ class OpenAIConfigMigration(migration.Migration):
del self.ap.provider_cfg.data['openai-config']
await self.ap.provider_cfg.dump_config()
await self.ap.provider_cfg.dump_config()

View File

@@ -3,26 +3,23 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("anthropic-requester-config-completion", 3)
@migration.migration_class('anthropic-requester-config-completion', 3)
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
"""OpenAI配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'anthropic-messages' not in self.ap.provider_cfg.data['requester']
or 'anthropic' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
'base-url': 'https://api.anthropic.com',
'args': {
'max_tokens': 1024
},
'args': {'max_tokens': 1024},
'timeout': 120,
}

View File

@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("moonshot-config-completion", 4)
@migration.migration_class('moonshot-config-completion', 4)
class MoonshotConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
"""OpenAI配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'moonshot' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
'base-url': 'https://api.moonshot.cn/v1',

View File

@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("deepseek-config-completion", 5)
@migration.migration_class('deepseek-config-completion', 5)
class DeepseekConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
"""OpenAI配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'deepseek' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = {
'base-url': 'https://api.deepseek.com',
@@ -27,4 +26,4 @@ class DeepseekConfigCompletionMigration(migration.Migration):
if 'deepseek' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['deepseek'] = []
await self.ap.provider_cfg.dump_config()
await self.ap.provider_cfg.dump_config()

View File

@@ -3,17 +3,17 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("vision-config", 6)
@migration.migration_class('vision-config', 6)
class VisionConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "enable-vision" not in self.ap.provider_cfg.data
return 'enable-vision' not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
if "enable-vision" not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data["enable-vision"] = False
if 'enable-vision' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data['enable-vision'] = False
await self.ap.provider_cfg.dump_config()

View File

@@ -3,18 +3,20 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("qcg-center-url-config", 7)
@migration.migration_class('qcg-center-url-config', 7)
class QCGCenterURLConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "qcg-center-url" not in self.ap.system_cfg.data
return 'qcg-center-url' not in self.ap.system_cfg.data
async def run(self):
"""执行迁移"""
if "qcg-center-url" not in self.ap.system_cfg.data:
self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2"
if 'qcg-center-url' not in self.ap.system_cfg.data:
self.ap.system_cfg.data['qcg-center-url'] = (
'https://api.qchatgpt.rockchin.top/api/v2'
)
await self.ap.system_cfg.dump_config()

View File

@@ -3,27 +3,27 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("ad-fixwin-cfg-migration", 8)
@migration.migration_class('ad-fixwin-cfg-migration', 8)
class AdFixwinConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return isinstance(
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"],
int
self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int
)
async def run(self):
"""执行迁移"""
for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]:
for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
temp_dict = {
"window-size": 60,
"limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name]
'window-size': 60,
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][
session_name
],
}
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict
await self.ap.pipeline_cfg.dump_config()
self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict
await self.ap.pipeline_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("msg-truncator-cfg-migration", 9)
@migration.migration_class('msg-truncator-cfg-migration', 9)
class MsgTruncatorConfigMigration(migration.Migration):
"""迁移"""
@@ -13,12 +13,10 @@ class MsgTruncatorConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.pipeline_cfg.data['msg-truncate'] = {
'method': 'round',
'round': {
'max-round': 10
}
'round': {'max-round': 10},
}
await self.ap.pipeline_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("ollama-requester-config", 10)
@migration.migration_class('ollama-requester-config', 10)
class MsgTruncatorConfigMigration(migration.Migration):
"""迁移"""
@@ -13,11 +13,11 @@ class MsgTruncatorConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['ollama-chat'] = {
"base-url": "http://127.0.0.1:11434",
"args": {},
"timeout": 600
'base-url': 'http://127.0.0.1:11434',
'args': {},
'timeout': 600,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("command-prefix-config", 11)
@migration.migration_class('command-prefix-config', 11)
class CommandPrefixConfigMigration(migration.Migration):
"""迁移"""
@@ -13,9 +13,7 @@ class CommandPrefixConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.command_cfg.data['command-prefix'] = [
"!", ""
]
self.ap.command_cfg.data['command-prefix'] = ['!', '']
await self.ap.command_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("runner-config", 12)
@migration.migration_class('runner-config', 12)
class RunnerConfigMigration(migration.Migration):
"""迁移"""
@@ -13,7 +13,7 @@ class RunnerConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['runner'] = 'local-agent'
await self.ap.provider_cfg.dump_config()

View File

@@ -3,29 +3,30 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("http-api-config", 13)
@migration.migration_class('http-api-config', 13)
class HttpApiConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'http-api' not in self.ap.system_cfg.data or "persistence" not in self.ap.system_cfg.data
return (
'http-api' not in self.ap.system_cfg.data
or 'persistence' not in self.ap.system_cfg.data
)
async def run(self):
"""执行迁移"""
self.ap.system_cfg.data['http-api'] = {
"enable": True,
"host": "0.0.0.0",
"port": 5300,
"jwt-expire": 604800
'enable': True,
'host': '0.0.0.0',
'port': 5300,
'jwt-expire': 604800,
}
self.ap.system_cfg.data['persistence'] = {
"sqlite": {
"path": "data/persistence.db"
},
"use": "sqlite"
'sqlite': {'path': 'data/persistence.db'},
'use': 'sqlite',
}
await self.ap.system_cfg.dump_config()

View File

@@ -3,20 +3,20 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("force-delay-config", 14)
@migration.migration_class('force-delay-config', 14)
class ForceDelayConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return type(self.ap.platform_cfg.data['force-delay']) == list
return isinstance(self.ap.platform_cfg.data['force-delay'], list)
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['force-delay'] = {
"min": self.ap.platform_cfg.data['force-delay'][0],
"max": self.ap.platform_cfg.data['force-delay'][1]
'min': self.ap.platform_cfg.data['force-delay'][0],
'max': self.ap.platform_cfg.data['force-delay'][1],
}
await self.ap.platform_cfg.dump_config()

View File

@@ -3,24 +3,25 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("gitee-ai-config", 15)
@migration.migration_class('gitee-ai-config', 15)
class GiteeAIConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
return (
'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = {
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
'base-url': 'https://ai.gitee.com/v1',
'args': {},
'timeout': 120,
}
self.ap.provider_cfg.data['keys']['gitee-ai'] = [
"XXXXX"
]
self.ap.provider_cfg.data['keys']['gitee-ai'] = ['XXXXX']
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("dify-service-api-config", 16)
@migration.migration_class('dify-service-api-config', 16)
class DifyServiceAPICfgMigration(migration.Migration):
"""迁移"""
@@ -14,15 +14,10 @@ class DifyServiceAPICfgMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dify-service-api'] = {
"base-url": "https://api.dify.ai/v1",
"app-type": "chat",
"chat": {
"api-key": "app-1234567890"
},
"workflow": {
"api-key": "app-1234567890",
"output-key": "summary"
}
'base-url': 'https://api.dify.ai/v1',
'app-type': 'chat',
'chat': {'api-key': 'app-1234567890'},
'workflow': {'api-key': 'app-1234567890', 'output-key': 'summary'},
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,22 +3,26 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("dify-api-timeout-params", 17)
@migration.migration_class('dify-api-timeout-params', 17)
class DifyAPITimeoutParamsMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] \
return (
'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat']
or 'timeout'
not in self.ap.provider_cfg.data['dify-service-api']['workflow']
or 'agent' not in self.ap.provider_cfg.data['dify-service-api']
)
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120
self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120
self.ap.provider_cfg.data['dify-service-api']['agent'] = {
"api-key": "app-1234567890",
"timeout": 120
'api-key': 'app-1234567890',
'timeout': 120,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("xai-config", 18)
@migration.migration_class('xai-config', 18)
class XaiConfigMigration(migration.Migration):
"""迁移"""
@@ -14,12 +14,10 @@ class XaiConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['xai-chat-completions'] = {
"base-url": "https://api.x.ai/v1",
"args": {},
"timeout": 120
'base-url': 'https://api.x.ai/v1',
'args': {},
'timeout': 120,
}
self.ap.provider_cfg.data['keys']['xai'] = [
"xai-1234567890"
]
self.ap.provider_cfg.data['keys']['xai'] = ['xai-1234567890']
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("zhipuai-config", 19)
@migration.migration_class('zhipuai-config', 19)
class ZhipuaiConfigMigration(migration.Migration):
"""迁移"""
@@ -14,12 +14,10 @@ class ZhipuaiConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['zhipuai-chat-completions'] = {
"base-url": "https://open.bigmodel.cn/api/paas/v4",
"args": {},
"timeout": 120
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
'args': {},
'timeout': 120,
}
self.ap.provider_cfg.data['keys']['zhipuai'] = [
"xxxxxxx"
]
self.ap.provider_cfg.data['keys']['zhipuai'] = ['xxxxxxx']
await self.ap.provider_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("wecom-config", 20)
@migration.migration_class('wecom-config', 20)
class WecomConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'wecom':
# return False
@@ -19,16 +19,18 @@ class WecomConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "wecom",
"enable": False,
"host": "0.0.0.0",
"port": 2290,
"corpid": "",
"secret": "",
"token": "",
"EncodingAESKey": "",
"contacts_secret": ""
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'wecom',
'enable': False,
'host': '0.0.0.0',
'port': 2290,
'corpid': '',
'secret': '',
'token': '',
'EncodingAESKey': '',
'contacts_secret': '',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("lark-config", 21)
@migration.migration_class('lark-config', 21)
class LarkConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'lark':
# return False
@@ -19,15 +19,17 @@ class LarkConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "lark",
"enable": False,
"app_id": "cli_abcdefgh",
"app_secret": "XXXXXXXXXX",
"bot_name": "LangBot",
"enable-webhook": False,
"port": 2285,
"encrypt-key": "xxxxxxxxx"
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'lark',
'enable': False,
'app_id': 'cli_abcdefgh',
'app_secret': 'XXXXXXXXXX',
'bot_name': 'LangBot',
'enable-webhook': False,
'port': 2285,
'encrypt-key': 'xxxxxxxxx',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,21 +3,21 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("lmstudio-config", 22)
@migration.migration_class('lmstudio-config', 22)
class LmStudioConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'lmstudio-chat-completions' not in self.ap.provider_cfg.data['requester']
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['lmstudio-chat-completions'] = {
"base-url": "http://127.0.0.1:1234/v1",
"args": {},
"timeout": 120
'base-url': 'http://127.0.0.1:1234/v1',
'args': {},
'timeout': 120,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,25 +3,25 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("siliconflow-config", 23)
@migration.migration_class('siliconflow-config', 23)
class SiliconFlowConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
return (
'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
)
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['keys']['siliconflow'] = [
"xxxxxxx"
]
self.ap.provider_cfg.data['keys']['siliconflow'] = ['xxxxxxx']
self.ap.provider_cfg.data['requester']['siliconflow-chat-completions'] = {
"base-url": "https://api.siliconflow.cn/v1",
"args": {},
"timeout": 120
'base-url': 'https://api.siliconflow.cn/v1',
'args': {},
'timeout': 120,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("discord-config", 24)
@migration.migration_class('discord-config', 24)
class DiscordConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'discord':
# return False
@@ -19,11 +19,13 @@ class DiscordConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "discord",
"enable": False,
"client_id": "1234567890",
"token": "XXXXXXXXXX"
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'discord',
'enable': False,
'client_id': '1234567890',
'token': 'XXXXXXXXXX',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("gewechat-config", 25)
@migration.migration_class('gewechat-config', 25)
class GewechatConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'gewechat':
# return False
@@ -19,15 +19,17 @@ class GewechatConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "gewechat",
"enable": False,
"gewechat_url": "http://your-gewechat-server:2531",
"gewechat_file_url": "http://your-gewechat-server:2532",
"port": 2286,
"callback_url": "http://your-callback-url:2286/gewechat/callback",
"app_id": "",
"token": ""
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'gewechat',
'enable': False,
'gewechat_url': 'http://your-gewechat-server:2531',
'gewechat_file_url': 'http://your-gewechat-server:2532',
'port': 2286,
'callback_url': 'http://your-callback-url:2286/gewechat/callback',
'app_id': '',
'token': '',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("qqofficial-config", 26)
@migration.migration_class('qqofficial-config', 26)
class QQOfficialConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'qqofficial':
# return False
@@ -19,13 +19,15 @@ class QQOfficialConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "qqofficial",
"enable": False,
"appid": "",
"secret": "",
"port": 2284,
"token": ""
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'qqofficial',
'enable': False,
'appid': '',
'secret': '',
'port': 2284,
'token': '',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("wx-official-account-config", 27)
@migration.migration_class('wx-official-account-config', 27)
class WXOfficialAccountConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'officialaccount':
# return False
@@ -19,15 +19,17 @@ class WXOfficialAccountConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "officialaccount",
"enable": False,
"token": "",
"EncodingAESKey": "",
"AppID": "",
"AppSecret": "",
"host": "0.0.0.0",
"port": 2287
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'officialaccount',
'enable': False,
'token': '',
'EncodingAESKey': '',
'AppID': '',
'AppSecret': '',
'host': '0.0.0.0',
'port': 2287,
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,25 +3,23 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("bailian-requester-config", 28)
@migration.migration_class('bailian-requester-config', 28)
class BailianRequesterConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'bailian-chat-completions' not in self.ap.provider_cfg.data['requester']
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['keys']['bailian'] = [
"sk-xxxxxxx"
]
self.ap.provider_cfg.data['keys']['bailian'] = ['sk-xxxxxxx']
self.ap.provider_cfg.data['requester']['bailian-chat-completions'] = {
"base-url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"args": {},
"timeout": 120
'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'args': {},
'timeout': 120,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("dashscope-app-api-config", 29)
@migration.migration_class('dashscope-app-api-config', 29)
class DashscopeAppAPICfgMigration(migration.Migration):
"""迁移"""
@@ -14,20 +14,14 @@ class DashscopeAppAPICfgMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dashscope-app-api'] = {
"app-type": "agent",
"api-key": "sk-1234567890",
"agent": {
"app-id": "Your_app_id",
"references_quote": "参考资料来自:"
'app-type': 'agent',
'api-key': 'sk-1234567890',
'agent': {'app-id': 'Your_app_id', 'references_quote': '参考资料来自:'},
'workflow': {
'app-id': 'Your_app_id',
'references_quote': '参考资料来自:',
'biz_params': {'city': '北京', 'date': '2023-08-10'},
},
"workflow": {
"app-id": "Your_app_id",
"references_quote": "参考资料来自:",
"biz_params": {
"city": "北京",
"date": "2023-08-10"
}
}
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("lark-config-cmpl", 30)
@migration.migration_class('lark-config-cmpl', 30)
class LarkConfigCmplMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
for adapter in self.ap.platform_cfg.data['platform-adapters']:
if adapter['adapter'] == 'lark':
if 'enable-webhook' not in adapter:
@@ -26,6 +26,6 @@ class LarkConfigCmplMigration(migration.Migration):
if 'port' not in adapter:
adapter['port'] = 2285
if 'encrypt-key' not in adapter:
adapter['encrypt-key'] = "xxxxxxxxx"
adapter['encrypt-key'] = 'xxxxxxxxx'
await self.ap.platform_cfg.dump_config()

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("dingtalk-config", 31)
@migration.migration_class('dingtalk-config', 31)
class DingTalkConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
# for adapter in self.ap.platform_cfg.data['platform-adapters']:
# if adapter['adapter'] == 'dingtalk':
# return False
@@ -19,13 +19,15 @@ class DingTalkConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({
"adapter": "dingtalk",
"enable": False,
"client_id": "",
"client_secret": "",
"robot_code": "",
"robot_name": ""
})
self.ap.platform_cfg.data['platform-adapters'].append(
{
'adapter': 'dingtalk',
'enable': False,
'client_id': '',
'client_secret': '',
'robot_code': '',
'robot_name': '',
}
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,25 +3,23 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("volcark-requester-config", 32)
@migration.migration_class('volcark-requester-config', 32)
class VolcArkRequesterConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'volcark-chat-completions' not in self.ap.provider_cfg.data['requester']
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['keys']['volcark'] = [
"xxxxxxxx"
]
self.ap.provider_cfg.data['keys']['volcark'] = ['xxxxxxxx']
self.ap.provider_cfg.data['requester']['volcark-chat-completions'] = {
"base-url": "https://ark.cn-beijing.volces.com/api/v3",
"args": {},
"timeout": 120
'base-url': 'https://ark.cn-beijing.volces.com/api/v3',
'args': {},
'timeout': 120,
}
await self.ap.provider_cfg.dump_config()

View File

@@ -3,24 +3,27 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("dify-thinking-config", 33)
@migration.migration_class('dify-thinking-config', 33)
class DifyThinkingConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
if 'options' not in self.ap.provider_cfg.data["dify-service-api"]:
if 'options' not in self.ap.provider_cfg.data['dify-service-api']:
return True
if 'convert-thinking-tips' not in self.ap.provider_cfg.data["dify-service-api"]["options"]:
if (
'convert-thinking-tips'
not in self.ap.provider_cfg.data['dify-service-api']['options']
):
return True
return False
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data["dify-service-api"]["options"] = {
"convert-thinking-tips": "plain"
self.ap.provider_cfg.data['dify-service-api']['options'] = {
'convert-thinking-tips': 'plain'
}
await self.ap.provider_cfg.dump_config()

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlparse
from .. import migration
@migration.migration_class("gewechat-file-url-config", 34)
@migration.migration_class('gewechat-file-url-config', 34)
class GewechatFileUrlConfigMigration(migration.Migration):
"""迁移"""
@@ -24,6 +24,8 @@ class GewechatFileUrlConfigMigration(migration.Migration):
if adapter['adapter'] == 'gewechat':
if 'gewechat_file_url' not in adapter:
parsed_url = urlparse(adapter['gewechat_url'])
adapter['gewechat_file_url'] = f"{parsed_url.scheme}://{parsed_url.hostname}:2532"
adapter['gewechat_file_url'] = (
f'{parsed_url.scheme}://{parsed_url.hostname}:2532'
)
await self.ap.platform_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("wxoa-mode", 35)
@migration.migration_class('wxoa-mode', 35)
class WxoaModeMigration(migration.Migration):
"""迁移"""

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("wxoa-loading-message", 36)
@migration.migration_class('wxoa-loading-message', 36)
class WxoaLoadingMessageMigration(migration.Migration):
"""迁移"""

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("mcp-config", 37)
@migration.migration_class('mcp-config', 37)
class MCPConfigMigration(migration.Migration):
"""迁移"""
@@ -13,8 +13,6 @@ class MCPConfigMigration(migration.Migration):
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['mcp'] = {
"servers": []
}
self.ap.provider_cfg.data['mcp'] = {'servers': []}
await self.ap.provider_cfg.dump_config()

View File

@@ -7,9 +7,10 @@ from . import app
preregistered_notes: list[typing.Type[LaunchNote]] = []
def note_class(name: str, number: int):
"""注册一个启动信息
"""
"""注册一个启动信息"""
def decorator(cls: typing.Type[LaunchNote]) -> typing.Type[LaunchNote]:
cls.name = name
cls.number = number
@@ -20,8 +21,8 @@ def note_class(name: str, number: int):
class LaunchNote(abc.ABC):
"""启动信息
"""
"""启动信息"""
name: str
number: int
@@ -33,12 +34,10 @@ class LaunchNote(abc.ABC):
@abc.abstractmethod
async def need_show(self) -> bool:
"""判断当前环境是否需要显示此启动信息
"""
"""判断当前环境是否需要显示此启动信息"""
pass
@abc.abstractmethod
async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
"""生成启动信息
"""
"""生成启动信息"""
pass

View File

@@ -2,19 +2,17 @@ from __future__ import annotations
import typing
from .. import note, app
from .. import note
@note.note_class("ClassicNotes", 1)
@note.note_class('ClassicNotes', 1)
class ClassicNotes(note.LaunchNote):
"""经典启动信息
"""
"""经典启动信息"""
async def need_show(self) -> bool:
return True
async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
yield await self.ap.ann_mgr.show_announcements()
yield await self.ap.ver_mgr.show_version_update()
yield await self.ap.ver_mgr.show_version_update()

View File

@@ -2,20 +2,20 @@ from __future__ import annotations
import typing
import os
import sys
import logging
from .. import note, app
from .. import note
@note.note_class("SelectionModeOnWindows", 2)
@note.note_class('SelectionModeOnWindows', 2)
class SelectionModeOnWindows(note.LaunchNote):
"""Windows 上的选择模式提示信息
"""
"""Windows 上的选择模式提示信息"""
async def need_show(self) -> bool:
return os.name == 'nt'
async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
yield """您正在使用 Windows 系统,若窗口左上角显示处于”选择“模式,程序将被暂停运行,此时请右键窗口中空白区域退出选择模式。""", logging.INFO
yield (
"""您正在使用 Windows 系统,若窗口左上角显示处于”选择“模式,程序将被暂停运行,此时请右键窗口中空白区域退出选择模式。""",
logging.INFO,
)

View File

@@ -1,21 +1,17 @@
from __future__ import annotations
import typing
import os
import sys
import logging
from .. import note, app
from .. import note
@note.note_class("PrintVersion", 3)
@note.note_class('PrintVersion', 3)
class PrintVersion(note.LaunchNote):
"""Print Version Information
"""
"""Print Version Information"""
async def need_show(self) -> bool:
return True
async def yield_note(self) -> typing.AsyncGenerator[typing.Tuple[str, int], None]:
yield f"Current Version: {self.ap.ver_mgr.get_current_version()}", logging.INFO
yield f'Current Version: {self.ap.ver_mgr.get_current_version()}', logging.INFO

View File

@@ -12,9 +12,8 @@ preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
当前阶段暂不支持扩展
"""
def stage_class(
name: str
):
def stage_class(name: str):
def decorator(cls: typing.Type[BootingStage]) -> typing.Type[BootingStage]:
preregistered_stages[name] = cls
return cls
@@ -23,12 +22,11 @@ def stage_class(
class BootingStage(abc.ABC):
"""启动阶段
"""
"""启动阶段"""
name: str = None
@abc.abstractmethod
async def run(self, ap: app.Application):
"""启动
"""
"""启动"""
pass

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import sys
from .. import stage, app
from ...utils import version, proxy, announce, platform
@@ -24,26 +23,22 @@ from ...utils import logcache
from .. import taskmgr
@stage.stage_class("BuildAppStage")
@stage.stage_class('BuildAppStage')
class BuildAppStage(stage.BootingStage):
"""构建应用阶段
"""
"""构建应用阶段"""
async def run(self, ap: app.Application):
"""构建app对象的各个组件对象并初始化
"""
"""构建app对象的各个组件对象并初始化"""
ap.task_mgr = taskmgr.AsyncTaskManager(ap)
discover = discover_engine.ComponentDiscoveryEngine(ap)
discover.discover_blueprint(
"components.yaml"
)
discover.discover_blueprint('components.yaml')
ap.discover = discover
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
@@ -52,14 +47,14 @@ class BuildAppStage(stage.BootingStage):
ap,
backend_url=ap.instance_config.data['telemetry']['url'],
basic_info={
"host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(),
"platform": platform.get_platform(),
'host_id': identifier.identifier['host_id'],
'instance_id': identifier.identifier['instance_id'],
'semantic_version': ver_mgr.get_current_version(),
'platform': platform.get_platform(),
},
runtime_info={
"admin_id": "{}".format(ap.instance_config.data["admins"]),
"msg_source": str([]),
'admin_id': '{}'.format(ap.instance_config.data['admins']),
'msg_source': str([]),
},
)
ap.ctr_mgr = center_v2_api

View File

@@ -1,20 +1,17 @@
from __future__ import annotations
import secrets
import os
from .. import stage, app
@stage.stage_class("GenKeysStage")
@stage.stage_class('GenKeysStage')
class GenKeysStage(stage.BootingStage):
"""生成密钥阶段
"""
"""生成密钥阶段"""
async def run(self, ap: app.Application):
"""启动
"""
"""启动"""
if not ap.instance_config.data['system']['jwt']['secret']:
ap.instance_config.data['system']['jwt']['secret'] = secrets.token_hex(16)
await ap.instance_config.dump_config()

View File

@@ -7,45 +7,80 @@ from .. import stage, app
from ..bootutils import config
@stage.stage_class("LoadConfigStage")
@stage.stage_class('LoadConfigStage')
class LoadConfigStage(stage.BootingStage):
"""加载配置文件阶段
"""
"""加载配置文件阶段"""
async def run(self, ap: app.Application):
"""启动
"""
"""启动"""
# ======= deprecated =======
if os.path.exists("data/config/command.json"):
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/legacy/command.json", completion=False)
if os.path.exists('data/config/command.json'):
ap.command_cfg = await config.load_json_config(
'data/config/command.json',
'templates/legacy/command.json',
completion=False,
)
if os.path.exists("data/config/pipeline.json"):
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/legacy/pipeline.json", completion=False)
if os.path.exists('data/config/pipeline.json'):
ap.pipeline_cfg = await config.load_json_config(
'data/config/pipeline.json',
'templates/legacy/pipeline.json',
completion=False,
)
if os.path.exists("data/config/platform.json"):
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/legacy/platform.json", completion=False)
if os.path.exists('data/config/platform.json'):
ap.platform_cfg = await config.load_json_config(
'data/config/platform.json',
'templates/legacy/platform.json',
completion=False,
)
if os.path.exists("data/config/provider.json"):
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/legacy/provider.json", completion=False)
if os.path.exists('data/config/provider.json'):
ap.provider_cfg = await config.load_json_config(
'data/config/provider.json',
'templates/legacy/provider.json',
completion=False,
)
if os.path.exists("data/config/system.json"):
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/legacy/system.json", completion=False)
if os.path.exists('data/config/system.json'):
ap.system_cfg = await config.load_json_config(
'data/config/system.json',
'templates/legacy/system.json',
completion=False,
)
if os.path.exists("data/metadata/instance-secret.json"):
ap.instance_secret_meta = await config.load_json_config("data/metadata/instance-secret.json", template_data={
'jwt_secret': secrets.token_hex(16)
})
if os.path.exists('data/metadata/instance-secret.json'):
ap.instance_secret_meta = await config.load_json_config(
'data/metadata/instance-secret.json',
template_data={'jwt_secret': secrets.token_hex(16)},
)
await ap.instance_secret_meta.dump_config()
# ======= deprecated =======
ap.instance_config = await config.load_yaml_config("data/config.yaml", "templates/config.yaml", completion=False)
ap.instance_config = await config.load_yaml_config(
'data/config.yaml', 'templates/config.yaml', completion=False
)
await ap.instance_config.dump_config()
ap.sensitive_meta = await config.load_json_config("data/metadata/sensitive-words.json", "templates/metadata/sensitive-words.json")
ap.sensitive_meta = await config.load_json_config(
'data/metadata/sensitive-words.json',
'templates/metadata/sensitive-words.json',
)
await ap.sensitive_meta.dump_config()
ap.pipeline_config_meta_trigger = await config.load_yaml_config("templates/metadata/pipeline/trigger.yaml", "templates/metadata/pipeline/trigger.yaml")
ap.pipeline_config_meta_safety = await config.load_yaml_config("templates/metadata/pipeline/safety.yaml", "templates/metadata/pipeline/safety.yaml")
ap.pipeline_config_meta_ai = await config.load_yaml_config("templates/metadata/pipeline/ai.yaml", "templates/metadata/pipeline/ai.yaml")
ap.pipeline_config_meta_output = await config.load_yaml_config("templates/metadata/pipeline/output.yaml", "templates/metadata/pipeline/output.yaml")
ap.pipeline_config_meta_trigger = await config.load_yaml_config(
'templates/metadata/pipeline/trigger.yaml',
'templates/metadata/pipeline/trigger.yaml',
)
ap.pipeline_config_meta_safety = await config.load_yaml_config(
'templates/metadata/pipeline/safety.yaml',
'templates/metadata/pipeline/safety.yaml',
)
ap.pipeline_config_meta_ai = await config.load_yaml_config(
'templates/metadata/pipeline/ai.yaml', 'templates/metadata/pipeline/ai.yaml'
)
ap.pipeline_config_meta_output = await config.load_yaml_config(
'templates/metadata/pipeline/output.yaml',
'templates/metadata/pipeline/output.yaml',
)

Some files were not shown because too many files have changed in this diff Show More