From 4a319b2b20ca1f8cfc4b3cde5ffe3441326206be Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 13 Jul 2025 18:41:04 +0800 Subject: [PATCH] feat: query-based apis --- pkg/pipeline/pool.py | 1 + pkg/pipeline/preproc/preproc.py | 3 +- pkg/plugin/handler.py | 56 +++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index 898cfad6..eb7df66b 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -52,6 +52,7 @@ class QueryPool: sender_id=sender_id, message_event=message_event, message_chain=message_chain, + variables={}, resp_messages=[], resp_message_chain=[], adapter=adapter, diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index b48ced64..8cbcf8c7 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -62,13 +62,14 @@ class PreProcessor(stage.PipelineStage): if llm_model.model_entity.abilities.__contains__('func_call'): query.use_funcs = await self.ap.tool_mgr.get_all_tools() - query.variables = { + variables = { 'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}', 'conversation_id': conversation.uuid, 'msg_create_time': ( int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()) ), } + query.variables.update(variables) # Check if this model supports vision, if not, remove all images # TODO this checking should be performed in runner, and in this stage, the image should be reserved diff --git a/pkg/plugin/handler.py b/pkg/plugin/handler.py index 8cd0f357..77fc46bd 100644 --- a/pkg/plugin/handler.py +++ b/pkg/plugin/handler.py @@ -107,6 +107,62 @@ class RuntimeConnectionHandler(handler.Handler): }, ) + @self.action(PluginToRuntimeAction.SET_QUERY_VAR) + async def set_query_var(data: dict[str, Any]) -> handler.ActionResponse: + """Set query var""" + query_id = data['query_id'] + key = data['key'] + value = data['value'] + + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + query.variables[key] = value + + return handler.ActionResponse.success( + data={}, + ) + + @self.action(PluginToRuntimeAction.GET_QUERY_VAR) + async def get_query_var(data: dict[str, Any]) -> handler.ActionResponse: + """Get query var""" + query_id = data['query_id'] + key = data['key'] + + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + return handler.ActionResponse.success( + data={ + 'value': query.variables[key], + }, + ) + + @self.action(PluginToRuntimeAction.GET_QUERY_VARS) + async def get_query_vars(data: dict[str, Any]) -> handler.ActionResponse: + """Get query vars""" + query_id = data['query_id'] + if query_id not in self.ap.query_pool.cached_queries: + return handler.ActionResponse.error( + message=f'Query with query_id {query_id} not found', + ) + + query = self.ap.query_pool.cached_queries[query_id] + + return handler.ActionResponse.success( + data={ + 'vars': query.variables, + }, + ) + async def ping(self) -> dict[str, Any]: """Ping the runtime""" return await self.call_action(