diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 8cf51463..42bb5b4c 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -33,7 +33,7 @@ class PreProcessor(stage.PipelineStage): """ session = await self.ap.sess_mgr.get_session(query) - conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config) + conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config['ai']['local-agent']['prompt']) # 设置query query.session = session diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index 93b1146e..5143f2bb 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -42,7 +42,7 @@ class SessionManager: self.session_list.append(session) return session - async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, pipeline_config: dict) -> core_entities.Conversation: + async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, prompt_config: list[dict]) -> core_entities.Conversation: """获取对话或创建对话""" if not session.conversations: @@ -51,7 +51,7 @@ class SessionManager: # set prompt prompt_messages = [] - for prompt_message in pipeline_config['ai']['local-agent']['prompt']: + for prompt_message in prompt_config: prompt_messages.append(provider_entities.Message(**prompt_message)) prompt = provider_entities.Prompt(