feat: add functions

This commit is contained in:
WangCham
2025-07-03 23:28:47 +08:00
committed by Junyan Qin
parent c4671fbf1c
commit 34fe8b324d
12 changed files with 75 additions and 49 deletions

View File

@@ -1,5 +1,4 @@
import quart
from __future__ import annotations
from .. import group
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
@@ -16,13 +15,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
rag = self.ap.knowledge_base_service.RAG_Manager()
@self.route('', methods=['POST', 'GET'])
async def _() -> str:
if quart.request.method == 'GET':
knowledge_bases = await rag.get_all_knowledge_bases()
knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases()
bases_list = [
{
"uuid": kb.id,
@@ -35,17 +34,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
msg='ok')
json_data = await quart.request.json
knowledge_base_uuid = await rag.create_knowledge_base(
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
json_data.get('name'),
json_data.get('description')
)
return self.success()
return self.success(code=0,
data={},
msg='ok')
@self.route('/<knowledge_base_uuid>', methods=['GET'])
@self.route('/<knowledge_base_uuid>', methods=['GET','DELETE'])
async def _(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid)
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid)
if knowledge_base is None:
return self.http_status(404, -1, 'knowledge base not found')
@@ -59,11 +60,14 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
},
msg='ok'
)
elif quart.request.method == 'DELETE':
await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid)
return self.success(code=0, msg='ok')
@self.route('/<knowledge_base_uuid>/files', methods=['GET'])
async def _(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
files = await rag.get_files_by_knowledge_base(knowledge_base_uuid)
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid)
return self.success(code=0,data=[{
"id": file.id,
"file_name": file.file_name,
@@ -73,11 +77,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
# delete specific file in knowledge base
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'])
async def _(knowledge_base_uuid: str, file_id: str) -> str:
await rag.delete_data_by_file_id(file_id)
await self.ap.knowledge_base_service.delete_data_by_file_id(file_id)
return self.success(code=0, msg='ok')
# delete specific kb
@self.route('/<knowledge_base_uuid>', methods=['DELETE'])
async def _(knowledge_base_uuid: str) -> str:
await rag.delete_kb_by_id(knowledge_base_uuid)
return self.success(code=0, msg='ok')

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import quart
from ... import group
from .. import group
@group.group_class('pipelines', '/api/v1/pipelines')

View File

@@ -27,10 +27,7 @@ from ..storage import mgr as storagemgr
from ..utils import logcache
from . import taskmgr
from . import entities as core_entities
from ...pkg.rag.knowledge import RAG_Manager
from pkg.rag.knowledge.RAG_Manager import RAG_Manager
class Application:
@@ -51,6 +48,7 @@ class Application:
model_mgr: llm_model_mgr.ModelManager = None
# TODO 移动到 pipeline 里
tool_mgr: llm_tool_mgr.ToolManager = None
@@ -103,7 +101,6 @@ class Application:
storage_mgr: storagemgr.StorageMgr = None
knowledge_base_service: RAG_Manager = None
# ========= HTTP Services =========
@@ -117,6 +114,8 @@ class Application:
bot_service: bot_service.BotService = None
knowledge_base_service: RAG_Manager = None
def __init__(self):
pass
@@ -152,6 +151,7 @@ class Application:
name='http-api-controller',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
self.task_mgr.create_task(
never_ending(),
name='never-ending-task',

View File

@@ -19,7 +19,7 @@ class LifecycleControlScope(enum.Enum):
APPLICATION = 'application'
PLATFORM = 'platform'
PLUGIN = 'plugin'
PROVIDER = 'provider'
PROVIDER = 'provider'
class LauncherTypes(enum.Enum):

View File

@@ -9,6 +9,7 @@ from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr
from ...platform import botmgr as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
@@ -101,6 +102,12 @@ class BuildAppStage(stage.BootingStage):
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
ap.embedding_models_service = embedding_models_service_inst
knowledge_base_service_inst = knowledge_base_mgr(ap)
print("knowledge_base_service_inst1", type(knowledge_base_service_inst))
await knowledge_base_service_inst.initialize_rag_system()
ap.knowledge_base_service = knowledge_base_service_inst
print("knowledge_base_service_inst", type(ap.knowledge_base_service))
pipeline_service_inst = pipeline_service.PipelineService(ap)
ap.pipeline_service = pipeline_service_inst

View File

@@ -0,0 +1,14 @@
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from datetime import datetime
import numpy as np # 用于处理从LargeBinary转换回来的embedding
Base = declarative_base()
class Vector(Base):
__tablename__ = 'vectors'
id = Column(Integer, primary_key=True, index=True)
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
embedding = Column(LargeBinary) # Store embeddings as binary
chunk = relationship("Chunk", back_populates="vector")

View File

@@ -1,18 +1,24 @@
# RAG_Manager class (main class, adjust imports as needed)
from __future__ import annotations # For type hinting in Python 3.7+
import logging
import os
import asyncio
from services.parser import FileParser
from services.chunker import Chunker
from services.embedder import Embedder
from services.retriever import Retriever
from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
from services.embedding_models import EmbeddingModelFactory
from services.chroma_manager import ChromaIndexManager
from ...core import app
from pkg.rag.knowledge.services.parser import FileParser
from pkg.rag.knowledge.services.chunker import Chunker
from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from pkg.core import app # Adjust the import path as needed
class RAG_Manager:
def __init__(self, logger: logging.Logger = None):
ap: app.Application
def __init__(self, ap: app.Application,logger: logging.Logger = None):
self.ap = ap
self.logger = logger or logging.getLogger(__name__)
self.embedding_model_type = None
self.embedding_model_name = None
@@ -21,11 +27,11 @@ class RAG_Manager:
self.chunker = None
self.embedder = None
self.retriever = None
async def initialize_system(self):
async def initialize_rag_system(self):
await asyncio.to_thread(create_db_and_tables)
async def create_model(self, embedding_model_type: str,
async def create_specific_model(self, embedding_model_type: str,
embedding_model_name: str):
self.embedding_model_type = embedding_model_type
self.embedding_model_name = embedding_model_name
@@ -57,7 +63,7 @@ class RAG_Manager:
)
async def create_knowledge_base(self, kb_name: str, kb_description: str):
async def create_knowledge_base(self, kb_name: str, kb_description: str ,):
"""
Creates a new knowledge base with the given name and description.
If a knowledge base with the same name already exists, it returns that one.

View File

@@ -1,20 +1,20 @@
# 封装异步操作
import asyncio
import logging
from services.database import SessionLocal # 导入 SessionLocal 工厂函数
from pkg.rag.knowledge.services.database import SessionLocal
class BaseService:
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数
self.db_session_factory = SessionLocal
async def _run_sync(self, func, *args, **kwargs):
"""
在单独的线程中运行同步函数。
如果第一个参数是 session则在 to_thread 中获取新的 session。
"""
# 如果函数需要数据库会话作为第一个参数,我们在这里获取它
if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头
if getattr(func, '__name__', '').startswith('_db_'):
session = await asyncio.to_thread(self.db_session_factory)
try:
result = await asyncio.to_thread(func, session, *args, **kwargs)

View File

@@ -1,4 +1,4 @@
# services/chroma_manager.py
import numpy as np
import logging
from chromadb import PersistentClient
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
class ChromaIndexManager:
def __init__(self, collection_name: str = "default_collection"):
self.logger = logging.getLogger(self.__class__.__name__)
chroma_data_path = "./chroma_data"
chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma"))
os.makedirs(chroma_data_path, exist_ok=True)
self.client = PersistentClient(path=chroma_data_path)
self._collection_name = collection_name

View File

@@ -1,7 +1,7 @@
# services/chunker.py
import logging
from typing import List
from services.base_service import BaseService # Assuming BaseService provides _run_sync
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
logger = logging.getLogger(__name__)

View File

@@ -4,10 +4,10 @@ import logging
import numpy as np
from typing import List
from sqlalchemy.orm import Session
from services.base_service import BaseService
from services.database import Chunk, SessionLocal
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from services.chroma_manager import ChromaIndexManager # Import the manager
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager
logger = logging.getLogger(__name__)

View File

@@ -4,10 +4,10 @@ import logging
import numpy as np # Make sure numpy is imported
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from services.base_service import BaseService
from services.database import Chunk, SessionLocal
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from services.chroma_manager import ChromaIndexManager
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
logger = logging.getLogger(__name__)