diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 133e6a80..17da8161 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -86,6 +86,7 @@ class CommandManager: privilege = 2 ctx = command_context.ExecuteContext( + query_id=query.query_id, session=session, command_text=command_text, command='', diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index b5cf664a..e71ff0dc 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -205,6 +205,7 @@ class RuntimePipeline: self.ap.logger.error(f'Traceback: {traceback.format_exc()}') finally: self.ap.logger.debug(f'Query {query} processed') + del self.ap.query_pool.cached_queries[query.query_id] class PipelineManager: diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index eb32fce6..898cfad6 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -19,12 +19,16 @@ class QueryPool: queries: list[pipeline_query.Query] + cached_queries: dict[int, pipeline_query.Query] + """Cached queries, used for plugin backward api call, will be removed after the query completely processed""" + condition: asyncio.Condition def __init__(self): self.query_id_counter = 0 self.pool_lock = asyncio.Lock() self.queries = [] + self.cached_queries = {} self.condition = asyncio.Condition(self.pool_lock) async def add_query( @@ -39,9 +43,10 @@ class QueryPool: pipeline_uuid: typing.Optional[str] = None, ) -> pipeline_query.Query: async with self.condition: + query_id = self.query_id_counter query = pipeline_query.Query( bot_uuid=bot_uuid, - query_id=self.query_id_counter, + query_id=query_id, launcher_type=launcher_type, launcher_id=launcher_id, sender_id=sender_id, @@ -53,6 +58,7 @@ class QueryPool: pipeline_uuid=pipeline_uuid, ) self.queries.append(query) + self.cached_queries[query_id] = query self.query_id_counter += 1 self.condition.notify_all() diff --git a/pkg/plugin/handler.py b/pkg/plugin/handler.py index a4a77ea6..8cd0f357 100644 --- a/pkg/plugin/handler.py +++ b/pkg/plugin/handler.py @@ -13,7 +13,6 @@ from langbot_plugin.entities.io.actions.enums import ( LangBotToRuntimeAction, PluginToRuntimeAction, ) -import langbot_plugin.api.entities.context as event_context_module import langbot_plugin.api.entities.builtin.platform.message as platform_message from ..entity.persistence import plugin as persistence_plugin @@ -68,21 +67,21 @@ class RuntimeConnectionHandler(handler.Handler): @self.action(PluginToRuntimeAction.REPLY_MESSAGE) async def reply_message(data: dict[str, Any]) -> handler.ActionResponse: """Reply message""" - eid = data['eid'] + query_id = data['query_id'] message_chain = data['message_chain'] quote_origin = data['quote_origin'] - if eid not in event_context_module.cached_event_contexts: + if query_id not in self.ap.query_pool.cached_queries: return handler.ActionResponse.error( - message=f'Event context with eid {eid} not found', + message=f'Query with query_id {query_id} not found', ) - event_context = event_context_module.cached_event_contexts[eid] + query = self.ap.query_pool.cached_queries[query_id] message_chain_obj = platform_message.MessageChain.model_validate(message_chain) - await event_context.event.query.adapter.reply_message( - event_context.event.query.message_event, + await query.adapter.reply_message( + query.message_event, message_chain_obj, quote_origin, ) @@ -91,6 +90,23 @@ class RuntimeConnectionHandler(handler.Handler): data={}, ) + @self.action(PluginToRuntimeAction.GET_BOT_UUID) + async def get_bot_uuid(data: dict[str, Any]) -> handler.ActionResponse: + """Get bot uuid""" + query_id = data['query_id'] + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + return handler.ActionResponse.success( + data={ + 'bot_uuid': query.bot_uuid, + }, + ) + async def ping(self) -> dict[str, Any]: """Ping the runtime""" return await self.call_action(