feat: add topk

This commit is contained in:
WangCham
2025-07-23 17:29:36 +08:00
parent 4e95bc542c
commit 47e9ce96fc
3 changed files with 4 additions and 4 deletions

View File

@@ -78,7 +78,7 @@ class KnowledgeService:
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
if runtime_kb is None:
raise Exception('Knowledge base not found')
return [result.model_dump() for result in await runtime_kb.retrieve(query)]
return [result.model_dump() for result in await runtime_kb.retrieve(query,runtime_kb.knowledge_base_entity.top_k)]
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
"""获取知识库文件"""

View File

@@ -56,7 +56,7 @@ class LocalAgentRunner(runner.RequestRunner):
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found')
raise ValueError(f'Knowledge base {kb_uuid} not found')
result = await kb.retrieve(user_message_text)
result = await kb.retrieve(user_message_text,kb.knowledge_base_entity.top_k)
final_user_message_text = ''

View File

@@ -123,11 +123,11 @@ class RuntimeKnowledgeBase:
)
return wrapper.id
async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]:
async def retrieve(self, query: str, top_k: int) -> list[retriever_entities.RetrieveResultEntry]:
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
self.knowledge_base_entity.embedding_model_uuid
)
return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model)
return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model, top_k)
async def delete_file(self, file_id: str):
# delete vector