chore: stash

This commit is contained in:
Junyan Qin
2025-07-15 22:09:10 +08:00
parent 199164fc4b
commit 67bc065ccd
15 changed files with 508 additions and 338 deletions

View File

@@ -14,7 +14,7 @@ preregistered_groups: list[type[RouterGroup]] = []
"""RouterGroup 的预注册列表"""
def group_class(name: str, path: str) -> None:
def group_class(name: str, path: str) -> typing.Callable[[typing.Type[RouterGroup]], typing.Type[RouterGroup]]:
"""注册一个 RouterGroup"""
def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]:
@@ -120,6 +120,6 @@ class RouterGroup(abc.ABC):
}
)
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
def http_status(self, status: int, code: int, msg: str) -> typing.Tuple[quart.Response, int]:
"""返回一个指定状态码的响应"""
return self.fail(code, msg), status
return (self.fail(code, msg), status)

View File

@@ -5,77 +5,49 @@ from ... import group
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
class KnowledgeBaseRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases')
async def handle_knowledge_bases() -> str:
@self.route('', methods=['POST', 'GET'])
async def handle_knowledge_bases() -> quart.Response:
if quart.request.method == 'GET':
knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases()
bases_list = [
{
'uuid': kb.id,
'name': kb.name,
'description': kb.description,
'embedding_model_uuid': kb.embedding_model_uuid,
'top_k': kb.top_k,
}
for kb in knowledge_bases
]
return self.success(data={'bases': bases_list})
knowledge_bases = await self.ap.knowledge_service.get_knowledge_bases()
return self.success(data={'bases': knowledge_bases})
elif quart.request.method == 'POST':
json_data = await quart.request.json
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
json_data.get('name'),
json_data.get('description'),
json_data.get('embedding_model_uuid'),
json_data.get('top_k',5),
)
knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data)
return self.success(data={'uuid': knowledge_base_uuid})
return self.http_status(405, -1, 'Method not allowed')
@self.route(
'/<knowledge_base_uuid>',
methods=['GET', 'DELETE'],
endpoint='handle_specific_knowledge_base',
)
async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str:
async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> quart.Response:
if quart.request.method == 'GET':
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid)
knowledge_base = await self.ap.knowledge_service.get_knowledge_base(knowledge_base_uuid)
if knowledge_base is None:
return self.http_status(404, -1, 'knowledge base not found')
return self.success(
data={
'base': {
'name': knowledge_base.name,
'description': knowledge_base.description,
'uuid': knowledge_base.id,
'embedding_model_uuid': knowledge_base.embedding_model_uuid,
'top_k': knowledge_base.top_k,
},
'base': knowledge_base,
}
)
elif quart.request.method == 'DELETE':
await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid)
await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid)
return self.success({})
@self.route(
'/<knowledge_base_uuid>/files',
methods=['GET', 'POST'],
endpoint='get_knowledge_base_files',
)
async def get_knowledge_base_files(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid)
files = await self.ap.knowledge_service.get_files_by_knowledge_base(knowledge_base_uuid)
return self.success(
data={
'files': [
{
'id': file.id,
'file_name': file.file_name,
'status': file.status,
}
for file in files
],
'files': files,
}
)
@@ -86,14 +58,17 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
return self.http_status(400, -1, 'File ID is required')
# 调用服务层方法将文件与知识库关联
await self.ap.knowledge_base_service.relate_file_id_with_kb(knowledge_base_uuid, file_id)
return self.success({})
task_id = await self.ap.knowledge_service.store_file(knowledge_base_uuid, file_id)
return self.success(
{
'task_id': task_id,
}
)
@self.route(
'/<knowledge_base_uuid>/files/<file_id>',
methods=['DELETE'],
endpoint='delete_specific_file_in_kb',
)
async def delete_specific_file_in_kb(file_id: str,knowledge_base_uuid: str) -> str:
await self.ap.knowledge_base_service.delete_data_by_file_id(file_id)
async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str:
await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id)
return self.success({})

View File

@@ -0,0 +1,90 @@
from __future__ import annotations
import uuid
import sqlalchemy
from ....core import app
from ....entity.persistence import rag as persistence_rag
class KnowledgeService:
"""知识库服务"""
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_knowledge_bases(self) -> list[dict]:
"""获取所有知识库"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
knowledge_bases = result.all()
return [
self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
for knowledge_base in knowledge_bases
]
async def get_knowledge_base(self, kb_uuid: str) -> dict | None:
"""获取知识库"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
)
knowledge_base = result.first()
if knowledge_base is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
async def create_knowledge_base(self, kb_data: dict) -> str:
"""创建知识库"""
kb_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data))
kb = await self.get_knowledge_base(kb_data['uuid'])
await self.ap.rag_mgr.load_knowledge_base(kb)
return kb_data['uuid']
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
"""更新知识库"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.KnowledgeBase)
.values(kb_data)
.where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
)
await self.ap.rag_mgr.remove_knowledge_base(kb_uuid)
kb = await self.get_knowledge_base(kb_uuid)
await self.ap.rag_mgr.load_knowledge_base(kb)
async def store_file(self, kb_uuid: str, file_id: str) -> int:
"""存储文件"""
# await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(kb_id=kb_uuid, file_id=file_id))
# await self.ap.rag_mgr.store_file(file_id)
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
if runtime_kb is None:
raise Exception('Knowledge base not found')
return await runtime_kb.store_file(file_id)
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
"""获取知识库文件"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
)
files = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_rag.File, file) for file in files]
async def delete_file(self, kb_uuid: str, file_id: str) -> None:
"""删除文件"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
)
# TODO: remove from memory
async def delete_knowledge_base(self, kb_uuid: str) -> None:
"""删除知识库"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
)
# TODO: remove from memory