Files
LangBot/pkg/core/controller.py
2024-01-26 15:51:49 +08:00

85 lines
2.7 KiB
Python

from __future__ import annotations
import asyncio
import traceback
from . import app, entities
from ..pipeline import entities as pipeline_entities
# Maximum number of queries the Controller processes at the same time:
# used as the initial value of Controller.semaphore, which gates each
# call to Controller.process_query.
DEFAULT_QUERY_CONCURRENCY = 10
class Controller:
    """Master controller.

    Takes queries from ``ap.query_pool`` and runs each one through every
    stage container in ``ap.stage_mgr``. At most ``DEFAULT_QUERY_CONCURRENCY``
    queries are processed concurrently, enforced by ``semaphore``.
    """

    ap: app.Application

    # Semaphore that bounds how many queries are processed concurrently.
    semaphore: asyncio.Semaphore = None

    # Strong references to in-flight query tasks. The asyncio event loop only
    # keeps weak references to tasks, so a task created with
    # ``asyncio.create_task`` and not referenced anywhere else may be
    # garbage-collected before it finishes.
    _background_tasks: set[asyncio.Task]

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY)
        self._background_tasks = set()

    async def consumer(self):
        """Event-processing loop.

        Repeatedly takes the oldest query from ``ap.query_pool.queries``
        (first come, first served) inside ``async with ap.query_pool``.
        When the pool is empty it waits on ``ap.query_pool.condition`` and
        re-checks. Each taken query is scheduled as a background task that
        waits for a concurrency slot before being processed. The loop has
        no exit condition.
        """
        while True:
            selected_query: entities.Query = None

            # Take a query from the pool.
            async with self.ap.query_pool:
                queries: list[entities.Query] = self.ap.query_pool.queries

                if queries:
                    selected_query = queries.pop(0)  # FCFS
                else:
                    # Wait until the condition is notified (presumably by the
                    # producer when a query is added), then loop to re-check.
                    await self.ap.query_pool.condition.wait()
                    continue

            if selected_query:
                task = asyncio.create_task(self._process_query_limited(selected_query))
                # Hold a strong reference until the task completes so it cannot
                # be garbage-collected mid-execution; drop it once done.
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

    async def _process_query_limited(self, query: entities.Query):
        """Run ``process_query`` for ``query`` once a semaphore slot is free."""
        async with self.semaphore:
            await self.process_query(query)

    async def process_query(self, query: entities.Query):
        """Run a query through every pipeline stage, in order.

        After each stage:
        - a truthy ``res.user_notice`` is sent via ``ap.im_mgr.send`` to the
          current query's ``message_event``;
        - ``res.debug_notice`` is logged at DEBUG and ``res.console_notice``
          at INFO.

        A result of type ``INTERRUPT`` stops the pipeline. A result of type
        ``CONTINUE`` replaces ``query`` with ``res.new_query`` for the
        remaining stages.

        Any ``Exception`` raised by a stage is logged as an error, its
        traceback is printed to stderr, and it is not re-raised.

        Args:
            query: the query to process.
        """
        self.ap.logger.debug(f"Processing query {query}")

        try:
            for stage_container in self.ap.stage_mgr.stage_containers:
                res = await stage_container.inst.process(query, stage_container.inst_name)

                self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}")

                if res.user_notice:
                    await self.ap.im_mgr.send(
                        query.message_event,
                        res.user_notice
                    )
                if res.debug_notice:
                    self.ap.logger.debug(res.debug_notice)
                if res.console_notice:
                    self.ap.logger.info(res.console_notice)

                if res.result_type == pipeline_entities.ResultType.INTERRUPT:
                    self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
                    break
                elif res.result_type == pipeline_entities.ResultType.CONTINUE:
                    # Later stages operate on the query produced by this stage.
                    query = res.new_query
        except Exception as e:
            self.ap.logger.error(f"处理请求时出错 {query}: {e}")
            traceback.print_exc()
        finally:
            self.ap.logger.debug(f"Query {query} processed")

    async def run(self):
        """Run the controller by entering the consumer loop.

        Does not return normally, because ``consumer`` loops forever.
        """
        await self.consumer()