mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-13 17:26:04 +00:00
144 lines
6.3 KiB
Python
144 lines
6.3 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import traceback
|
|
|
|
from ..core import app
|
|
from ..core import entities as core_entities
|
|
|
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
|
|
|
|
class Controller:
    """Top-level controller that dispatches queued queries to their pipelines.

    Two concurrency limits apply: the global ``self.semaphore`` (sized from
    ``instance_config.data['concurrency']['pipeline']``) and each session's
    ``_semaphore``, which the consumer acquires before dispatching a query
    and ``_process_query`` releases when the query is done.
    """

    ap: app.Application

    # Global request-concurrency semaphore; created in __init__.
    semaphore: asyncio.Semaphore | None = None

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])

    async def _try_claim_steering_before_session_slot(
        self,
        query: pipeline_query.Query,
    ) -> bool:
        """Claim steering while the normal per-session slot is still busy.

        Follow-up input must be claimed before it waits behind the session
        semaphore; otherwise the active run can finish before the query reaches
        ChatMessageHandler.try_claim_steering_from_query.

        As a side effect this fills in ``query.session``,
        ``query.pipeline_config`` and the ``_pipeline_bound_*`` variables,
        even when the claim is ultimately refused.

        Returns:
            True if the orchestrator accepted the query as steering input.
            False if it was refused, has no or an unknown pipeline, or the
            attempt raised (the error is logged).
        """
        try:
            pipeline_uuid = query.pipeline_uuid
            if not pipeline_uuid:
                return False

            pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
            if not pipeline:
                return False

            session = await self.ap.sess_mgr.get_session(query)
            query.session = session
            query.pipeline_config = pipeline.pipeline_entity.config
            query.variables['_pipeline_bound_plugins'] = pipeline.bound_plugins
            query.variables['_pipeline_bound_mcp_servers'] = pipeline.bound_mcp_servers

            return await self.ap.agent_run_orchestrator.try_claim_steering_from_query(query)
        except Exception as exc:
            # Best effort: a failed claim just means the query waits for the
            # normal session slot instead.
            self.ap.logger.warning(
                f'Failed to claim query {query.query_id} as steering input: {exc}',
                exc_info=True,
            )
            return False

    async def _process_query(self, query: pipeline_query.Query) -> None:
        """Run *query* through its pipeline, then free its session slot.

        The caller must already hold the session semaphore for *query*. The
        semaphore is released in ``finally``, so a pipeline that raises (or a
        task that is cancelled) cannot leave the session permanently blocked.
        Exceptions from the pipeline still propagate to the task.
        """
        try:
            async with self.semaphore:  # global concurrency limit
                # Look the pipeline up by uuid for every query rather than
                # trusting the bot adapter's config, which may be stale. For
                # example, an aiocqhttp client connection survives an adapter
                # update and restart.
                pipeline_uuid = query.pipeline_uuid
                if not pipeline_uuid:
                    self.ap.logger.warning(f'No pipeline_uuid for query {query.query_id}, query dropped')
                    return

                pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
                if not pipeline:
                    self.ap.logger.warning(
                        f'Pipeline {pipeline_uuid} not found for query {query.query_id}, query dropped'
                    )
                    return

                await pipeline.run(query)
        finally:
            async with self.ap.query_pool:
                session = await self.ap.sess_mgr.get_session(query)
                session._semaphore.release()
                # Wake the consumer: a session slot has been freed.
                self.ap.query_pool.condition.notify_all()

    async def consumer(self):
        """Event loop: pick runnable queries from the pool and dispatch them.

        Under the query-pool lock, queued queries are scanned in order:

        * If a query's session slot is busy, try to claim it as steering input
          for the active run. A claimed query is removed from the pool and
          from ``cached_queries``.
        * Otherwise, the first query with a free session slot acquires that
          slot, is removed from the pool, and is dispatched as a
          ``_process_query`` task.
        * If nothing is runnable, wait on the pool's condition.
        """
        try:
            while True:
                selected_query: pipeline_query.Query | None = None
                claimed_steering_query: pipeline_query.Query | None = None

                async with self.ap.query_pool:
                    queries: list[pipeline_query.Query] = self.ap.query_pool.queries

                    for query in queries:
                        # Keep per-query logging out of this scan: it runs on
                        # every wake-up and floods the log under load.
                        session = await self.ap.sess_mgr.get_session(query)

                        if session._semaphore.locked():
                            if await self._try_claim_steering_before_session_slot(query):
                                claimed_steering_query = query
                                self.ap.logger.debug(f'Claimed query {query.query_id} as steering before session slot')
                                break
                            continue

                        selected_query = query
                        await session._semaphore.acquire()
                        self.ap.logger.debug(f'Selected query {query.query_id} for processing')
                        break

                    if claimed_steering_query:
                        queries.remove(claimed_steering_query)
                        self.ap.query_pool.cached_queries.pop(claimed_steering_query.query_id, None)
                        self.ap.query_pool.condition.notify_all()
                        continue

                    if not selected_query:
                        # Either the pool is empty, or every queued query's
                        # session is at its concurrency limit.
                        await self.ap.query_pool.condition.wait()
                        continue

                    queries.remove(selected_query)

                self.ap.task_mgr.create_task(
                    self._process_query(selected_query),
                    kind='query',
                    name=f'query-{selected_query.query_id}',
                    scopes=[
                        core_entities.LifecycleControlScope.APPLICATION,
                        core_entities.LifecycleControlScope.PLATFORM,
                    ],
                )

        except Exception as e:
            # NOTE(review): any unexpected error ends the consumer loop for
            # good, so no further queries are dispatched. Consider a supervised
            # restart, but avoid busy-looping on a query that keeps failing.
            self.ap.logger.error(f'控制器循环出错: {e}')
            self.ap.logger.error(f'Traceback: {traceback.format_exc()}')

    async def run(self):
        """Run the controller by awaiting :meth:`consumer`."""
        await self.consumer()
|