From 47e9ce96fca746991ff9e389f019e55f4f8fcaa9 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Wed, 23 Jul 2025 17:29:36 +0800 Subject: [PATCH] feat: add topk --- pkg/api/http/service/knowledge.py | 2 +- pkg/provider/runners/localagent.py | 2 +- pkg/rag/knowledge/kbmgr.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/api/http/service/knowledge.py b/pkg/api/http/service/knowledge.py index 27506ec9..ed4ab008 100644 --- a/pkg/api/http/service/knowledge.py +++ b/pkg/api/http/service/knowledge.py @@ -78,7 +78,7 @@ class KnowledgeService: runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) if runtime_kb is None: raise Exception('Knowledge base not found') - return [result.model_dump() for result in await runtime_kb.retrieve(query)] + return [result.model_dump() for result in await runtime_kb.retrieve(query,runtime_kb.knowledge_base_entity.top_k)] async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]: """获取知识库文件""" diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 1d3e88ac..3ccb5573 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -56,7 +56,7 @@ class LocalAgentRunner(runner.RequestRunner): self.ap.logger.warning(f'Knowledge base {kb_uuid} not found') raise ValueError(f'Knowledge base {kb_uuid} not found') - result = await kb.retrieve(user_message_text) + result = await kb.retrieve(user_message_text,kb.knowledge_base_entity.top_k) final_user_message_text = '' diff --git a/pkg/rag/knowledge/kbmgr.py b/pkg/rag/knowledge/kbmgr.py index a9e7e57a..1cdef361 100644 --- a/pkg/rag/knowledge/kbmgr.py +++ b/pkg/rag/knowledge/kbmgr.py @@ -123,11 +123,11 @@ class RuntimeKnowledgeBase: ) return wrapper.id - async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]: + async def retrieve(self, query: str, top_k: int) -> list[retriever_entities.RetrieveResultEntry]: embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid( self.knowledge_base_entity.embedding_model_uuid ) - return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model) + return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model, top_k) async def delete_file(self, file_id: str): # delete vector