mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-09 23:36:02 +00:00
refactor: 请求处理控制流基础架构
This commit is contained in:
0
pkg/core/__init__.py
Normal file
0
pkg/core/__init__.py
Normal file
49
pkg/core/app.py
Normal file
49
pkg/core/app.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..qqbot import manager as qqbot_mgr
|
||||
from ..openai import manager as openai_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..database import manager as database_mgr
|
||||
from ..utils.center import v2 as center_mgr
|
||||
from ..plugin import host as plugin_host
|
||||
from . import pool, controller
|
||||
from ..pipeline import stagemgr
|
||||
|
||||
|
||||
class Application:
|
||||
im_mgr: qqbot_mgr.QQBotManager = None
|
||||
|
||||
llm_mgr: openai_mgr.OpenAIInteract = None
|
||||
|
||||
cfg_mgr: config_mgr.ConfigManager = None
|
||||
|
||||
tips_mgr: config_mgr.ConfigManager = None
|
||||
|
||||
db_mgr: database_mgr.DatabaseManager = None
|
||||
|
||||
ctr_mgr: center_mgr.V2CenterAPI = None
|
||||
|
||||
query_pool: pool.QueryPool = None
|
||||
|
||||
ctrl: controller.Controller = None
|
||||
|
||||
stage_mgr: stagemgr.StageManager = None
|
||||
|
||||
logger: logging.Logger = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
# TODO make it async
|
||||
plugin_host.initialize_plugins()
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(self.im_mgr.run()),
|
||||
asyncio.create_task(self.ctrl.run())
|
||||
]
|
||||
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
138
pkg/core/boot.py
Normal file
138
pkg/core/boot.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from .bootutils import files
|
||||
from .bootutils import deps
|
||||
from .bootutils import log
|
||||
from .bootutils import config
|
||||
|
||||
from . import app
|
||||
from . import pool
|
||||
from . import controller
|
||||
from ..pipeline import stagemgr
|
||||
from ..audit import identifier
|
||||
from ..database import manager as db_mgr
|
||||
from ..openai import manager as llm_mgr
|
||||
from ..openai import session as llm_session
|
||||
from ..openai import dprompt as llm_dprompt
|
||||
from ..qqbot import manager as im_mgr
|
||||
from ..qqbot.cmds import aamgr as im_cmd_aamgr
|
||||
from ..plugin import host as plugin_host
|
||||
from ..utils.center import v2 as center_v2
|
||||
from ..utils import updater
|
||||
from ..utils import context
|
||||
|
||||
use_override = False
|
||||
|
||||
|
||||
async def make_app() -> app.Application:
|
||||
global use_override
|
||||
|
||||
generated_files = await files.generate_files()
|
||||
|
||||
if generated_files:
|
||||
print("以下文件不存在,已自动生成,请修改配置文件后重启:")
|
||||
for file in generated_files:
|
||||
print("-", file)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
missing_deps = await deps.check_deps()
|
||||
|
||||
if missing_deps:
|
||||
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
|
||||
for dep in missing_deps:
|
||||
print("-", dep)
|
||||
await deps.install_deps(missing_deps)
|
||||
sys.exit(0)
|
||||
|
||||
qcg_logger = await log.init_logging()
|
||||
|
||||
# 生成标识符
|
||||
identifier.init()
|
||||
|
||||
cfg_mgr = await config.load_python_module_config(
|
||||
"config.py",
|
||||
"config-template.py"
|
||||
)
|
||||
context.set_config_manager(cfg_mgr)
|
||||
cfg = cfg_mgr.data
|
||||
|
||||
# 检查是否携带了 --override 或 -r 参数
|
||||
if '--override' in sys.argv or '-r' in sys.argv:
|
||||
use_override = True
|
||||
|
||||
if use_override:
|
||||
overrided = await config.override_config_manager(cfg_mgr)
|
||||
if overrided:
|
||||
qcg_logger.info("以下配置项已使用 override.json 覆盖:" + ",".join(overrided))
|
||||
|
||||
tips_mgr = await config.load_python_module_config(
|
||||
"tips.py",
|
||||
"tips-custom-template.py"
|
||||
)
|
||||
|
||||
# 检查管理员QQ号
|
||||
if cfg_mgr.data['admin_qq'] == 0:
|
||||
qcg_logger.warning("未设置管理员QQ号,将无法使用管理员命令,请在 config.py 中修改 admin_qq")
|
||||
|
||||
# TODO make it async
|
||||
llm_dprompt.register_all()
|
||||
im_cmd_aamgr.register_all()
|
||||
im_cmd_aamgr.apply_privileges()
|
||||
|
||||
# 构建组建实例
|
||||
ap = app.Application()
|
||||
ap.logger = qcg_logger
|
||||
ap.cfg_mgr = cfg_mgr
|
||||
ap.tips_mgr = tips_mgr
|
||||
|
||||
ap.query_pool = pool.QueryPool()
|
||||
|
||||
center_v2_api = center_v2.V2CenterAPI(
|
||||
basic_info={
|
||||
"host_id": identifier.identifier['host_id'],
|
||||
"instance_id": identifier.identifier['instance_id'],
|
||||
"semantic_version": updater.get_current_tag(),
|
||||
"platform": sys.platform,
|
||||
},
|
||||
runtime_info={
|
||||
"admin_id": "{}".format(cfg['admin_qq']),
|
||||
"msg_source": cfg['msg_source_adapter'],
|
||||
}
|
||||
)
|
||||
ap.ctr_mgr = center_v2_api
|
||||
|
||||
db_mgr_inst = db_mgr.DatabaseManager(ap)
|
||||
# TODO make it async
|
||||
db_mgr_inst.initialize_database()
|
||||
ap.db_mgr = db_mgr_inst
|
||||
|
||||
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
|
||||
ap.llm_mgr = llm_mgr_inst
|
||||
# TODO make it async
|
||||
llm_session.load_sessions()
|
||||
|
||||
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.im_mgr = im_mgr_inst
|
||||
|
||||
stage_mgr = stagemgr.StageManager(ap)
|
||||
await stage_mgr.initialize()
|
||||
ap.stage_mgr = stage_mgr
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
|
||||
# TODO make it async
|
||||
plugin_host.load_plugins()
|
||||
# plugin_host.initialize_plugins()
|
||||
|
||||
return ap
|
||||
|
||||
|
||||
async def main():
|
||||
app_inst = await make_app()
|
||||
await app_inst.run()
|
||||
0
pkg/core/bootutils/__init__.py
Normal file
0
pkg/core/bootutils/__init__.py
Normal file
21
pkg/core/bootutils/config.py
Normal file
21
pkg/core/bootutils/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import json
|
||||
|
||||
from ...config import manager as config_mgr
|
||||
from ...config.impls import pymodule
|
||||
|
||||
|
||||
load_python_module_config = config_mgr.load_python_module_config
|
||||
load_json_config = config_mgr.load_json_config
|
||||
|
||||
|
||||
async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:
|
||||
override_json = json.load(open("override.json", "r", encoding="utf-8"))
|
||||
overrided = []
|
||||
|
||||
config = cfg_mgr.data
|
||||
for key in override_json:
|
||||
if key in config:
|
||||
config[key] = override_json[key]
|
||||
overrided.append(key)
|
||||
|
||||
return overrided
|
||||
34
pkg/core/bootutils/deps.py
Normal file
34
pkg/core/bootutils/deps.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import pip
|
||||
|
||||
required_deps = {
|
||||
"requests": "requests",
|
||||
"openai": "openai",
|
||||
"dulwich": "dulwich",
|
||||
"colorlog": "colorlog",
|
||||
"mirai": "yiri-mirai-rc",
|
||||
"func_timeout": "func_timeout",
|
||||
"PIL": "pillow",
|
||||
"nakuru": "nakuru-project-idk",
|
||||
"CallingGPT": "CallingGPT",
|
||||
"tiktoken": "tiktoken",
|
||||
"yaml": "pyyaml",
|
||||
"aiohttp": "aiohttp",
|
||||
}
|
||||
|
||||
|
||||
async def check_deps() -> list[str]:
|
||||
global required_deps
|
||||
|
||||
missing_deps = []
|
||||
for dep in required_deps:
|
||||
try:
|
||||
__import__(dep)
|
||||
except ImportError:
|
||||
missing_deps.append(dep)
|
||||
return missing_deps
|
||||
|
||||
async def install_deps(deps: list[str]):
|
||||
global required_deps
|
||||
|
||||
for dep in deps:
|
||||
pip.main(["install", required_deps[dep]])
|
||||
37
pkg/core/bootutils/files.py
Normal file
37
pkg/core/bootutils/files.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
|
||||
required_files = {
|
||||
"config.py": "config-template.py",
|
||||
"banlist.py": "banlist-template.py",
|
||||
"tips.py": "tips-custom-template.py",
|
||||
"sensitive.json": "res/templates/sensitive-template.json",
|
||||
"scenario/default.json": "scenario/default-template.json",
|
||||
"cmdpriv.json": "res/templates/cmdpriv-template.json",
|
||||
}
|
||||
|
||||
required_paths = [
|
||||
"plugins",
|
||||
"prompts",
|
||||
"temp",
|
||||
"logs"
|
||||
]
|
||||
|
||||
async def generate_files() -> list[str]:
|
||||
global required_files, required_paths
|
||||
|
||||
for required_paths in required_paths:
|
||||
if not os.path.exists(required_paths):
|
||||
os.mkdir(required_paths)
|
||||
|
||||
generated_files = []
|
||||
for file in required_files:
|
||||
if not os.path.exists(file):
|
||||
shutil.copyfile(required_files[file], file)
|
||||
generated_files.append(file)
|
||||
|
||||
return generated_files
|
||||
56
pkg/core/bootutils/log.py
Normal file
56
pkg/core/bootutils/log.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import colorlog
|
||||
|
||||
|
||||
log_colors_config = {
|
||||
"DEBUG": "green", # cyan white
|
||||
"INFO": "white",
|
||||
"WARNING": "yellow",
|
||||
"ERROR": "red",
|
||||
"CRITICAL": "cyan",
|
||||
}
|
||||
|
||||
|
||||
async def init_logging() -> logging.Logger:
|
||||
level = logging.INFO
|
||||
|
||||
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
|
||||
level = logging.DEBUG
|
||||
|
||||
log_file_name = "logs/qcg-%s.log" % time.strftime(
|
||||
"%Y-%m-%d-%H-%M-%S", time.localtime()
|
||||
)
|
||||
|
||||
qcg_logger = logging.getLogger("qcg")
|
||||
|
||||
qcg_logger.setLevel(level)
|
||||
|
||||
color_formatter = colorlog.ColoredFormatter(
|
||||
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
log_colors=log_colors_config,
|
||||
)
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)]
|
||||
|
||||
for handler in log_handlers:
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(color_formatter)
|
||||
qcg_logger.addHandler(handler)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # 设置日志输出格式
|
||||
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
||||
# 日志输出的格式
|
||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
|
||||
handlers=[logging.NullHandler()],
|
||||
)
|
||||
|
||||
return qcg_logger
|
||||
0
pkg/core/bootutils/misc.py
Normal file
0
pkg/core/bootutils/misc.py
Normal file
84
pkg/core/controller.py
Normal file
84
pkg/core/controller.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from . import app, entities
|
||||
from ..pipeline import entities as pipeline_entities
|
||||
|
||||
DEFAULT_QUERY_CONCURRENCY = 10
|
||||
|
||||
|
||||
class Controller:
|
||||
"""总控制器
|
||||
"""
|
||||
ap: app.Application
|
||||
|
||||
semaphore: asyncio.Semaphore = None
|
||||
"""请求并发控制信号量"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY)
|
||||
|
||||
async def consumer(self):
|
||||
"""事件处理循环
|
||||
"""
|
||||
while True:
|
||||
selected_query: entities.Query = None
|
||||
|
||||
# 取请求
|
||||
async with self.ap.query_pool:
|
||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
||||
|
||||
if queries:
|
||||
selected_query = queries.pop(0) # FCFS
|
||||
else:
|
||||
await self.ap.query_pool.condition.wait()
|
||||
continue
|
||||
|
||||
if selected_query:
|
||||
async def _process_query(selected_query):
|
||||
async with self.semaphore:
|
||||
await self.process_query(selected_query)
|
||||
|
||||
asyncio.create_task(_process_query(selected_query))
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
try:
|
||||
for stage_container in self.ap.stage_mgr.stage_containers:
|
||||
res = await stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}")
|
||||
|
||||
if res.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
query.message_event,
|
||||
res.user_notice
|
||||
)
|
||||
if res.debug_notice:
|
||||
self.ap.logger.debug(res.debug_notice)
|
||||
if res.console_notice:
|
||||
self.ap.logger.info(res.console_notice)
|
||||
|
||||
if res.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif res.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = res.new_query
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
async def run(self):
|
||||
"""运行控制器
|
||||
"""
|
||||
await self.consumer()
|
||||
41
pkg/core/entities.py
Normal file
41
pkg/core/entities.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
|
||||
PERSON = 'person'
|
||||
"""私聊"""
|
||||
|
||||
GROUP = 'group'
|
||||
"""群聊"""
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""一次请求的信息封装"""
|
||||
|
||||
query_id: int
|
||||
"""请求ID"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
"""会话类型"""
|
||||
|
||||
launcher_id: int
|
||||
"""会话ID"""
|
||||
|
||||
sender_id: int
|
||||
"""发送者ID"""
|
||||
|
||||
message_event: mirai.MessageEvent
|
||||
"""事件"""
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
"""消息链"""
|
||||
|
||||
resp_message_chain: typing.Optional[mirai.MessageChain] = None
|
||||
"""回复消息链"""
|
||||
52
pkg/core/pool.py
Normal file
52
pkg/core/pool.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import mirai
|
||||
|
||||
from . import entities
|
||||
|
||||
|
||||
class QueryPool:
|
||||
|
||||
query_id_counter: int = 0
|
||||
|
||||
pool_lock: asyncio.Lock
|
||||
|
||||
queries: list[entities.Query]
|
||||
|
||||
condition: asyncio.Condition
|
||||
|
||||
def __init__(self):
|
||||
self.query_id_counter = 0
|
||||
self.pool_lock = asyncio.Lock()
|
||||
self.queries = []
|
||||
self.condition = asyncio.Condition(self.pool_lock)
|
||||
|
||||
async def add_query(
|
||||
self,
|
||||
launcher_type: entities.LauncherTypes,
|
||||
launcher_id: int,
|
||||
sender_id: int,
|
||||
message_event: mirai.MessageEvent,
|
||||
message_chain: mirai.MessageChain
|
||||
) -> entities.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
query_id=self.query_id_counter,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain
|
||||
)
|
||||
self.queries.append(query)
|
||||
self.query_id_counter += 1
|
||||
self.condition.notify_all()
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.pool_lock.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
self.pool_lock.release()
|
||||
Reference in New Issue
Block a user