From 34fe8b324d0b5e1412641ef3b6029372c2327a90 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Thu, 3 Jul 2025 23:28:47 +0800 Subject: [PATCH] feat: add functions --- .../http/controller/groups/knowledge_base.py | 27 ++++++++-------- .../controller/groups/pipelines/pipelines.py | 2 +- pkg/core/app.py | 10 +++--- pkg/core/entities.py | 2 +- pkg/core/stages/build_app.py | 7 ++++ pkg/entity/persistence/vector.py | 14 ++++++++ pkg/rag/knowledge/RAG_Manager.py | 32 +++++++++++-------- pkg/rag/knowledge/services/base_service.py | 8 ++--- pkg/rag/knowledge/services/chroma_manager.py | 4 +-- pkg/rag/knowledge/services/chunker.py | 2 +- pkg/rag/knowledge/services/embedder.py | 8 ++--- pkg/rag/knowledge/services/retriever.py | 8 ++--- 12 files changed, 75 insertions(+), 49 deletions(-) create mode 100644 pkg/entity/persistence/vector.py diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py index c819397a..f9aa09e0 100644 --- a/pkg/api/http/controller/groups/knowledge_base.py +++ b/pkg/api/http/controller/groups/knowledge_base.py @@ -1,5 +1,4 @@ import quart -from __future__ import annotations from .. import group @group.group_class('knowledge_base', '/api/v1/knowledge/bases') @@ -16,13 +15,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): async def initialize(self) -> None: - rag = self.ap.knowledge_base_service.RAG_Manager() + @self.route('', methods=['POST', 'GET']) async def _() -> str: if quart.request.method == 'GET': - knowledge_bases = await rag.get_all_knowledge_bases() + knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases() bases_list = [ { "uuid": kb.id, @@ -35,17 +34,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): msg='ok') json_data = await quart.request.json - knowledge_base_uuid = await rag.create_knowledge_base( + knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( json_data.get('name'), json_data.get('description') ) - return self.success() + return self.success(code=0, + data={}, + msg='ok') - @self.route('/', methods=['GET']) + @self.route('/', methods=['GET','DELETE']) async def _(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid) + knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid) if knowledge_base is None: return self.http_status(404, -1, 'knowledge base not found') @@ -59,11 +60,14 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): }, msg='ok' ) + elif quart.request.method == 'DELETE': + await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid) + return self.success(code=0, msg='ok') @self.route('//files', methods=['GET']) async def _(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - files = await rag.get_files_by_knowledge_base(knowledge_base_uuid) + files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) return self.success(code=0,data=[{ "id": file.id, "file_name": file.file_name, @@ -73,11 +77,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): # delete specific file in knowledge base @self.route('//files/', methods=['DELETE']) async def _(knowledge_base_uuid: str, file_id: str) -> str: - await rag.delete_data_by_file_id(file_id) + await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) return self.success(code=0, msg='ok') - # delete specific kb - @self.route('/', methods=['DELETE']) - async def _(knowledge_base_uuid: str) -> str: - await rag.delete_kb_by_id(knowledge_base_uuid) - return self.success(code=0, msg='ok') diff --git a/pkg/api/http/controller/groups/pipelines/pipelines.py b/pkg/api/http/controller/groups/pipelines/pipelines.py index 96ca239a..1a8036cc 100644 --- a/pkg/api/http/controller/groups/pipelines/pipelines.py +++ b/pkg/api/http/controller/groups/pipelines/pipelines.py @@ -2,7 +2,7 @@ from __future__ import annotations import quart -from ... import group +from .. import group @group.group_class('pipelines', '/api/v1/pipelines') diff --git a/pkg/core/app.py b/pkg/core/app.py index d8824466..2e3c9500 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -27,10 +27,7 @@ from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities -from ...pkg.rag.knowledge import RAG_Manager - - - +from pkg.rag.knowledge.RAG_Manager import RAG_Manager class Application: @@ -51,6 +48,7 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None + # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None @@ -103,7 +101,6 @@ class Application: storage_mgr: storagemgr.StorageMgr = None - knowledge_base_service: RAG_Manager = None # ========= HTTP Services ========= @@ -117,6 +114,8 @@ class Application: bot_service: bot_service.BotService = None + knowledge_base_service: RAG_Manager = None + def __init__(self): pass @@ -152,6 +151,7 @@ class Application: name='http-api-controller', scopes=[core_entities.LifecycleControlScope.APPLICATION], ) + self.task_mgr.create_task( never_ending(), name='never-ending-task', diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 4caf18ed..8dc51e5b 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -19,7 +19,7 @@ class LifecycleControlScope(enum.Enum): APPLICATION = 'application' PLATFORM = 'platform' PLUGIN = 'plugin' - PROVIDER = 'provider' + PROVIDER = 'provider' class LauncherTypes(enum.Enum): diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 482a468b..3ba468c8 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -9,6 +9,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.tools import toolmgr as llm_tool_mgr +from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr from ...platform import botmgr as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -101,6 +102,12 @@ class BuildAppStage(stage.BootingStage): embedding_models_service_inst = model_service.EmbeddingModelsService(ap) ap.embedding_models_service = embedding_models_service_inst + knowledge_base_service_inst = knowledge_base_mgr(ap) + print("knowledge_base_service_inst1", type(knowledge_base_service_inst)) + await knowledge_base_service_inst.initialize_rag_system() + ap.knowledge_base_service = knowledge_base_service_inst + print("knowledge_base_service_inst", type(ap.knowledge_base_service)) + pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst diff --git a/pkg/entity/persistence/vector.py b/pkg/entity/persistence/vector.py new file mode 100644 index 00000000..84d1dfb1 --- /dev/null +++ b/pkg/entity/persistence/vector.py @@ -0,0 +1,14 @@ +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary +from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from datetime import datetime +import numpy as np # 用于处理从LargeBinary转换回来的embedding + +Base = declarative_base() + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship("Chunk", back_populates="vector") \ No newline at end of file diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index d85699d3..292f23ce 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -1,18 +1,24 @@ # RAG_Manager class (main class, adjust imports as needed) +from __future__ import annotations # For type hinting in Python 3.7+ import logging import os import asyncio -from services.parser import FileParser -from services.chunker import Chunker -from services.embedder import Embedder -from services.retriever import Retriever -from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly -from services.embedding_models import EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager -from ...core import app +from pkg.rag.knowledge.services.parser import FileParser +from pkg.rag.knowledge.services.chunker import Chunker +from pkg.rag.knowledge.services.embedder import Embedder +from pkg.rag.knowledge.services.retriever import Retriever +from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly +from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager +from pkg.core import app # Adjust the import path as needed + class RAG_Manager: - def __init__(self, logger: logging.Logger = None): + + ap: app.Application + + def __init__(self, ap: app.Application,logger: logging.Logger = None): + self.ap = ap self.logger = logger or logging.getLogger(__name__) self.embedding_model_type = None self.embedding_model_name = None @@ -21,11 +27,11 @@ class RAG_Manager: self.chunker = None self.embedder = None self.retriever = None - - async def initialize_system(self): + + async def initialize_rag_system(self): await asyncio.to_thread(create_db_and_tables) - async def create_model(self, embedding_model_type: str, + async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str): self.embedding_model_type = embedding_model_type self.embedding_model_name = embedding_model_name @@ -57,7 +63,7 @@ class RAG_Manager: ) - async def create_knowledge_base(self, kb_name: str, kb_description: str): + async def create_knowledge_base(self, kb_name: str, kb_description: str ,): """ Creates a new knowledge base with the given name and description. If a knowledge base with the same name already exists, it returns that one. diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py index 0298226a..4ff1ce39 100644 --- a/pkg/rag/knowledge/services/base_service.py +++ b/pkg/rag/knowledge/services/base_service.py @@ -1,20 +1,20 @@ # 封装异步操作 import asyncio import logging -from services.database import SessionLocal # 导入 SessionLocal 工厂函数 +from pkg.rag.knowledge.services.database import SessionLocal class BaseService: def __init__(self): self.logger = logging.getLogger(self.__class__.__name__) - self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数 + self.db_session_factory = SessionLocal async def _run_sync(self, func, *args, **kwargs): """ 在单独的线程中运行同步函数。 如果第一个参数是 session,则在 to_thread 中获取新的 session。 """ - # 如果函数需要数据库会话作为第一个参数,我们在这里获取它 - if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头 + + if getattr(func, '__name__', '').startswith('_db_'): session = await asyncio.to_thread(self.db_session_factory) try: result = await asyncio.to_thread(func, session, *args, **kwargs) diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py index 6a469168..f8020cdb 100644 --- a/pkg/rag/knowledge/services/chroma_manager.py +++ b/pkg/rag/knowledge/services/chroma_manager.py @@ -1,4 +1,4 @@ -# services/chroma_manager.py + import numpy as np import logging from chromadb import PersistentClient @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) class ChromaIndexManager: def __init__(self, collection_name: str = "default_collection"): self.logger = logging.getLogger(self.__class__.__name__) - chroma_data_path = "./chroma_data" + chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma")) os.makedirs(chroma_data_path, exist_ok=True) self.client = PersistentClient(path=chroma_data_path) self._collection_name = collection_name diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py index f115dac4..17202a7a 100644 --- a/pkg/rag/knowledge/services/chunker.py +++ b/pkg/rag/knowledge/services/chunker.py @@ -1,7 +1,7 @@ # services/chunker.py import logging from typing import List -from services.base_service import BaseService # Assuming BaseService provides _run_sync +from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync logger = logging.getLogger(__name__) diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 2b581e96..7e20b19a 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -4,10 +4,10 @@ import logging import numpy as np from typing import List from sqlalchemy.orm import Session -from services.base_service import BaseService -from services.database import Chunk, SessionLocal -from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager # Import the manager +from pkg.rag.knowledge.services.base_service import BaseService +from pkg.rag.knowledge.services.database import Chunk, SessionLocal +from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager logger = logging.getLogger(__name__) diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index 6da1c5d8..4da81eb1 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -4,10 +4,10 @@ import logging import numpy as np # Make sure numpy is imported from typing import List, Dict, Any from sqlalchemy.orm import Session -from services.base_service import BaseService -from services.database import Chunk, SessionLocal -from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager +from pkg.rag.knowledge.services.base_service import BaseService +from pkg.rag.knowledge.services.database import Chunk, SessionLocal +from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager logger = logging.getLogger(__name__)