refactor: 请求处理控制流基础架构

This commit is contained in:
RockChinQ
2024-01-26 15:51:49 +08:00
parent a064c24f60
commit 8d084427d2
55 changed files with 1430 additions and 146 deletions

0
pkg/core/__init__.py Normal file
View File

49
pkg/core/app.py Normal file
View File

@@ -0,0 +1,49 @@
from __future__ import annotations
import logging
import asyncio
from ..qqbot import manager as qqbot_mgr
from ..openai import manager as openai_mgr
from ..config import manager as config_mgr
from ..database import manager as database_mgr
from ..utils.center import v2 as center_mgr
from ..plugin import host as plugin_host
from . import pool, controller
from ..pipeline import stagemgr
class Application:
im_mgr: qqbot_mgr.QQBotManager = None
llm_mgr: openai_mgr.OpenAIInteract = None
cfg_mgr: config_mgr.ConfigManager = None
tips_mgr: config_mgr.ConfigManager = None
db_mgr: database_mgr.DatabaseManager = None
ctr_mgr: center_mgr.V2CenterAPI = None
query_pool: pool.QueryPool = None
ctrl: controller.Controller = None
stage_mgr: stagemgr.StageManager = None
logger: logging.Logger = None
def __init__(self):
pass
async def run(self):
# TODO make it async
plugin_host.initialize_plugins()
tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.ctrl.run())
]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

138
pkg/core/boot.py Normal file
View File

@@ -0,0 +1,138 @@
from __future__ import print_function
import os
import sys
from .bootutils import files
from .bootutils import deps
from .bootutils import log
from .bootutils import config
from . import app
from . import pool
from . import controller
from ..pipeline import stagemgr
from ..audit import identifier
from ..database import manager as db_mgr
from ..openai import manager as llm_mgr
from ..openai import session as llm_session
from ..openai import dprompt as llm_dprompt
from ..qqbot import manager as im_mgr
from ..qqbot.cmds import aamgr as im_cmd_aamgr
from ..plugin import host as plugin_host
from ..utils.center import v2 as center_v2
from ..utils import updater
from ..utils import context
use_override = False
async def make_app() -> app.Application:
global use_override
generated_files = await files.generate_files()
if generated_files:
print("以下文件不存在,已自动生成,请修改配置文件后重启:")
for file in generated_files:
print("-", file)
sys.exit(0)
missing_deps = await deps.check_deps()
if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
for dep in missing_deps:
print("-", dep)
await deps.install_deps(missing_deps)
sys.exit(0)
qcg_logger = await log.init_logging()
# 生成标识符
identifier.init()
cfg_mgr = await config.load_python_module_config(
"config.py",
"config-template.py"
)
context.set_config_manager(cfg_mgr)
cfg = cfg_mgr.data
# 检查是否携带了 --override 或 -r 参数
if '--override' in sys.argv or '-r' in sys.argv:
use_override = True
if use_override:
overrided = await config.override_config_manager(cfg_mgr)
if overrided:
qcg_logger.info("以下配置项已使用 override.json 覆盖:" + ",".join(overrided))
tips_mgr = await config.load_python_module_config(
"tips.py",
"tips-custom-template.py"
)
# 检查管理员QQ号
if cfg_mgr.data['admin_qq'] == 0:
qcg_logger.warning("未设置管理员QQ号将无法使用管理员命令请在 config.py 中修改 admin_qq")
# TODO make it async
llm_dprompt.register_all()
im_cmd_aamgr.register_all()
im_cmd_aamgr.apply_privileges()
# 构建组建实例
ap = app.Application()
ap.logger = qcg_logger
ap.cfg_mgr = cfg_mgr
ap.tips_mgr = tips_mgr
ap.query_pool = pool.QueryPool()
center_v2_api = center_v2.V2CenterAPI(
basic_info={
"host_id": identifier.identifier['host_id'],
"instance_id": identifier.identifier['instance_id'],
"semantic_version": updater.get_current_tag(),
"platform": sys.platform,
},
runtime_info={
"admin_id": "{}".format(cfg['admin_qq']),
"msg_source": cfg['msg_source_adapter'],
}
)
ap.ctr_mgr = center_v2_api
db_mgr_inst = db_mgr.DatabaseManager(ap)
# TODO make it async
db_mgr_inst.initialize_database()
ap.db_mgr = db_mgr_inst
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
ap.llm_mgr = llm_mgr_inst
# TODO make it async
llm_session.load_sessions()
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl
# TODO make it async
plugin_host.load_plugins()
# plugin_host.initialize_plugins()
return ap
async def main():
app_inst = await make_app()
await app_inst.run()

View File

View File

@@ -0,0 +1,21 @@
import json
from ...config import manager as config_mgr
from ...config.impls import pymodule
load_python_module_config = config_mgr.load_python_module_config
load_json_config = config_mgr.load_json_config
async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:
override_json = json.load(open("override.json", "r", encoding="utf-8"))
overrided = []
config = cfg_mgr.data
for key in override_json:
if key in config:
config[key] = override_json[key]
overrided.append(key)
return overrided

View File

@@ -0,0 +1,34 @@
import pip
required_deps = {
"requests": "requests",
"openai": "openai",
"dulwich": "dulwich",
"colorlog": "colorlog",
"mirai": "yiri-mirai-rc",
"func_timeout": "func_timeout",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"CallingGPT": "CallingGPT",
"tiktoken": "tiktoken",
"yaml": "pyyaml",
"aiohttp": "aiohttp",
}
async def check_deps() -> list[str]:
global required_deps
missing_deps = []
for dep in required_deps:
try:
__import__(dep)
except ImportError:
missing_deps.append(dep)
return missing_deps
async def install_deps(deps: list[str]):
global required_deps
for dep in deps:
pip.main(["install", required_deps[dep]])

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
import os
import shutil
import sys
required_files = {
"config.py": "config-template.py",
"banlist.py": "banlist-template.py",
"tips.py": "tips-custom-template.py",
"sensitive.json": "res/templates/sensitive-template.json",
"scenario/default.json": "scenario/default-template.json",
"cmdpriv.json": "res/templates/cmdpriv-template.json",
}
required_paths = [
"plugins",
"prompts",
"temp",
"logs"
]
async def generate_files() -> list[str]:
global required_files, required_paths
for required_paths in required_paths:
if not os.path.exists(required_paths):
os.mkdir(required_paths)
generated_files = []
for file in required_files:
if not os.path.exists(file):
shutil.copyfile(required_files[file], file)
generated_files.append(file)
return generated_files

56
pkg/core/bootutils/log.py Normal file
View File

@@ -0,0 +1,56 @@
import logging
import os
import sys
import time
import colorlog
log_colors_config = {
"DEBUG": "green", # cyan white
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "cyan",
}
async def init_logging() -> logging.Logger:
level = logging.INFO
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
level = logging.DEBUG
log_file_name = "logs/qcg-%s.log" % time.strftime(
"%Y-%m-%d-%H-%M-%S", time.localtime()
)
qcg_logger = logging.getLogger("qcg")
qcg_logger.setLevel(level)
color_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
log_colors=log_colors_config,
)
stream_handler = logging.StreamHandler(sys.stdout)
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)]
for handler in log_handlers:
handler.setLevel(level)
handler.setFormatter(color_formatter)
qcg_logger.addHandler(handler)
logging.basicConfig(
level=logging.INFO, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
handlers=[logging.NullHandler()],
)
return qcg_logger

View File

84
pkg/core/controller.py Normal file
View File

@@ -0,0 +1,84 @@
from __future__ import annotations
import asyncio
import traceback
from . import app, entities
from ..pipeline import entities as pipeline_entities
DEFAULT_QUERY_CONCURRENCY = 10
class Controller:
"""总控制器
"""
ap: app.Application
semaphore: asyncio.Semaphore = None
"""请求并发控制信号量"""
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY)
async def consumer(self):
"""事件处理循环
"""
while True:
selected_query: entities.Query = None
# 取请求
async with self.ap.query_pool:
queries: list[entities.Query] = self.ap.query_pool.queries
if queries:
selected_query = queries.pop(0) # FCFS
else:
await self.ap.query_pool.condition.wait()
continue
if selected_query:
async def _process_query(selected_query):
async with self.semaphore:
await self.process_query(selected_query)
asyncio.create_task(_process_query(selected_query))
async def process_query(self, query: entities.Query):
"""处理请求
"""
self.ap.logger.debug(f"Processing query {query}")
try:
for stage_container in self.ap.stage_mgr.stage_containers:
res = await stage_container.inst.process(query, stage_container.inst_name)
self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}")
if res.user_notice:
await self.ap.im_mgr.send(
query.message_event,
res.user_notice
)
if res.debug_notice:
self.ap.logger.debug(res.debug_notice)
if res.console_notice:
self.ap.logger.info(res.console_notice)
if res.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif res.result_type == pipeline_entities.ResultType.CONTINUE:
query = res.new_query
continue
except Exception as e:
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")
async def run(self):
"""运行控制器
"""
await self.consumer()

41
pkg/core/entities.py Normal file
View File

@@ -0,0 +1,41 @@
from __future__ import annotations
import enum
import typing
import pydantic
import mirai
class LauncherTypes(enum.Enum):
PERSON = 'person'
"""私聊"""
GROUP = 'group'
"""群聊"""
class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID"""
launcher_type: LauncherTypes
"""会话类型"""
launcher_id: int
"""会话ID"""
sender_id: int
"""发送者ID"""
message_event: mirai.MessageEvent
"""事件"""
message_chain: mirai.MessageChain
"""消息链"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链"""

52
pkg/core/pool.py Normal file
View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import asyncio
import mirai
from . import entities
class QueryPool:
query_id_counter: int = 0
pool_lock: asyncio.Lock
queries: list[entities.Query]
condition: asyncio.Condition
def __init__(self):
self.query_id_counter = 0
self.pool_lock = asyncio.Lock()
self.queries = []
self.condition = asyncio.Condition(self.pool_lock)
async def add_query(
self,
launcher_type: entities.LauncherTypes,
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain
) -> entities.Query:
async with self.condition:
query = entities.Query(
query_id=self.query_id_counter,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain
)
self.queries.append(query)
self.query_id_counter += 1
self.condition.notify_all()
async def __aenter__(self):
await self.pool_lock.acquire()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.pool_lock.release()