mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: add functions
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import quart
|
||||
from __future__ import annotations
|
||||
from .. import group
|
||||
|
||||
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
|
||||
@@ -16,13 +15,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
|
||||
|
||||
async def initialize(self) -> None:
|
||||
rag = self.ap.knowledge_base_service.RAG_Manager()
|
||||
|
||||
|
||||
@self.route('', methods=['POST', 'GET'])
|
||||
async def _() -> str:
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
knowledge_bases = await rag.get_all_knowledge_bases()
|
||||
knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases()
|
||||
bases_list = [
|
||||
{
|
||||
"uuid": kb.id,
|
||||
@@ -35,17 +34,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
msg='ok')
|
||||
|
||||
json_data = await quart.request.json
|
||||
knowledge_base_uuid = await rag.create_knowledge_base(
|
||||
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
|
||||
json_data.get('name'),
|
||||
json_data.get('description')
|
||||
)
|
||||
return self.success()
|
||||
return self.success(code=0,
|
||||
data={},
|
||||
msg='ok')
|
||||
|
||||
|
||||
@self.route('/<knowledge_base_uuid>', methods=['GET'])
|
||||
@self.route('/<knowledge_base_uuid>', methods=['GET','DELETE'])
|
||||
async def _(knowledge_base_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid)
|
||||
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid)
|
||||
|
||||
if knowledge_base is None:
|
||||
return self.http_status(404, -1, 'knowledge base not found')
|
||||
@@ -59,11 +60,14 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
},
|
||||
msg='ok'
|
||||
)
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid)
|
||||
return self.success(code=0, msg='ok')
|
||||
|
||||
@self.route('/<knowledge_base_uuid>/files', methods=['GET'])
|
||||
async def _(knowledge_base_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
files = await rag.get_files_by_knowledge_base(knowledge_base_uuid)
|
||||
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid)
|
||||
return self.success(code=0,data=[{
|
||||
"id": file.id,
|
||||
"file_name": file.file_name,
|
||||
@@ -73,11 +77,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
# delete specific file in knowledge base
|
||||
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'])
|
||||
async def _(knowledge_base_uuid: str, file_id: str) -> str:
|
||||
await rag.delete_data_by_file_id(file_id)
|
||||
await self.ap.knowledge_base_service.delete_data_by_file_id(file_id)
|
||||
return self.success(code=0, msg='ok')
|
||||
|
||||
# delete specific kb
|
||||
@self.route('/<knowledge_base_uuid>', methods=['DELETE'])
|
||||
async def _(knowledge_base_uuid: str) -> str:
|
||||
await rag.delete_kb_by_id(knowledge_base_uuid)
|
||||
return self.success(code=0, msg='ok')
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('pipelines', '/api/v1/pipelines')
|
||||
|
||||
@@ -27,10 +27,7 @@ from ..storage import mgr as storagemgr
|
||||
from ..utils import logcache
|
||||
from . import taskmgr
|
||||
from . import entities as core_entities
|
||||
from ...pkg.rag.knowledge import RAG_Manager
|
||||
|
||||
|
||||
|
||||
from pkg.rag.knowledge.RAG_Manager import RAG_Manager
|
||||
|
||||
|
||||
class Application:
|
||||
@@ -51,6 +48,7 @@ class Application:
|
||||
|
||||
model_mgr: llm_model_mgr.ModelManager = None
|
||||
|
||||
|
||||
# TODO 移动到 pipeline 里
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
|
||||
@@ -103,7 +101,6 @@ class Application:
|
||||
|
||||
storage_mgr: storagemgr.StorageMgr = None
|
||||
|
||||
knowledge_base_service: RAG_Manager = None
|
||||
|
||||
# ========= HTTP Services =========
|
||||
|
||||
@@ -117,6 +114,8 @@ class Application:
|
||||
|
||||
bot_service: bot_service.BotService = None
|
||||
|
||||
knowledge_base_service: RAG_Manager = None
|
||||
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -152,6 +151,7 @@ class Application:
|
||||
name='http-api-controller',
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
self.task_mgr.create_task(
|
||||
never_ending(),
|
||||
name='never-ending-task',
|
||||
|
||||
@@ -19,7 +19,7 @@ class LifecycleControlScope(enum.Enum):
|
||||
APPLICATION = 'application'
|
||||
PLATFORM = 'platform'
|
||||
PLUGIN = 'plugin'
|
||||
PROVIDER = 'provider'
|
||||
PROVIDER = 'provider'
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
|
||||
@@ -9,6 +9,7 @@ from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr
|
||||
from ...platform import botmgr as im_mgr
|
||||
from ...persistence import mgr as persistencemgr
|
||||
from ...api.http.controller import main as http_controller
|
||||
@@ -101,6 +102,12 @@ class BuildAppStage(stage.BootingStage):
|
||||
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
||||
ap.embedding_models_service = embedding_models_service_inst
|
||||
|
||||
knowledge_base_service_inst = knowledge_base_mgr(ap)
|
||||
print("knowledge_base_service_inst1", type(knowledge_base_service_inst))
|
||||
await knowledge_base_service_inst.initialize_rag_system()
|
||||
ap.knowledge_base_service = knowledge_base_service_inst
|
||||
print("knowledge_base_service_inst", type(ap.knowledge_base_service))
|
||||
|
||||
pipeline_service_inst = pipeline_service.PipelineService(ap)
|
||||
ap.pipeline_service = pipeline_service_inst
|
||||
|
||||
|
||||
14
pkg/entity/persistence/vector.py
Normal file
14
pkg/entity/persistence/vector.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
|
||||
from datetime import datetime
|
||||
import numpy as np # 用于处理从LargeBinary转换回来的embedding
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Vector(Base):
|
||||
__tablename__ = 'vectors'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
|
||||
embedding = Column(LargeBinary) # Store embeddings as binary
|
||||
|
||||
chunk = relationship("Chunk", back_populates="vector")
|
||||
@@ -1,18 +1,24 @@
|
||||
# RAG_Manager class (main class, adjust imports as needed)
|
||||
from __future__ import annotations # For type hinting in Python 3.7+
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from services.parser import FileParser
|
||||
from services.chunker import Chunker
|
||||
from services.embedder import Embedder
|
||||
from services.retriever import Retriever
|
||||
from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
|
||||
from services.embedding_models import EmbeddingModelFactory
|
||||
from services.chroma_manager import ChromaIndexManager
|
||||
from ...core import app
|
||||
from pkg.rag.knowledge.services.parser import FileParser
|
||||
from pkg.rag.knowledge.services.chunker import Chunker
|
||||
from pkg.rag.knowledge.services.embedder import Embedder
|
||||
from pkg.rag.knowledge.services.retriever import Retriever
|
||||
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
|
||||
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
from pkg.core import app # Adjust the import path as needed
|
||||
|
||||
|
||||
class RAG_Manager:
|
||||
def __init__(self, logger: logging.Logger = None):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application,logger: logging.Logger = None):
|
||||
self.ap = ap
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.embedding_model_type = None
|
||||
self.embedding_model_name = None
|
||||
@@ -21,11 +27,11 @@ class RAG_Manager:
|
||||
self.chunker = None
|
||||
self.embedder = None
|
||||
self.retriever = None
|
||||
|
||||
async def initialize_system(self):
|
||||
|
||||
async def initialize_rag_system(self):
|
||||
await asyncio.to_thread(create_db_and_tables)
|
||||
|
||||
async def create_model(self, embedding_model_type: str,
|
||||
async def create_specific_model(self, embedding_model_type: str,
|
||||
embedding_model_name: str):
|
||||
self.embedding_model_type = embedding_model_type
|
||||
self.embedding_model_name = embedding_model_name
|
||||
@@ -57,7 +63,7 @@ class RAG_Manager:
|
||||
)
|
||||
|
||||
|
||||
async def create_knowledge_base(self, kb_name: str, kb_description: str):
|
||||
async def create_knowledge_base(self, kb_name: str, kb_description: str ,):
|
||||
"""
|
||||
Creates a new knowledge base with the given name and description.
|
||||
If a knowledge base with the same name already exists, it returns that one.
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# 封装异步操作
|
||||
import asyncio
|
||||
import logging
|
||||
from services.database import SessionLocal # 导入 SessionLocal 工厂函数
|
||||
from pkg.rag.knowledge.services.database import SessionLocal
|
||||
|
||||
class BaseService:
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数
|
||||
self.db_session_factory = SessionLocal
|
||||
|
||||
async def _run_sync(self, func, *args, **kwargs):
|
||||
"""
|
||||
在单独的线程中运行同步函数。
|
||||
如果第一个参数是 session,则在 to_thread 中获取新的 session。
|
||||
"""
|
||||
# 如果函数需要数据库会话作为第一个参数,我们在这里获取它
|
||||
if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头
|
||||
|
||||
if getattr(func, '__name__', '').startswith('_db_'):
|
||||
session = await asyncio.to_thread(self.db_session_factory)
|
||||
try:
|
||||
result = await asyncio.to_thread(func, session, *args, **kwargs)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# services/chroma_manager.py
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
from chromadb import PersistentClient
|
||||
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
|
||||
class ChromaIndexManager:
|
||||
def __init__(self, collection_name: str = "default_collection"):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
chroma_data_path = "./chroma_data"
|
||||
chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma"))
|
||||
os.makedirs(chroma_data_path, exist_ok=True)
|
||||
self.client = PersistentClient(path=chroma_data_path)
|
||||
self._collection_name = collection_name
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# services/chunker.py
|
||||
import logging
|
||||
from typing import List
|
||||
from services.base_service import BaseService # Assuming BaseService provides _run_sync
|
||||
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ import logging
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from sqlalchemy.orm import Session
|
||||
from services.base_service import BaseService
|
||||
from services.database import Chunk, SessionLocal
|
||||
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from services.chroma_manager import ChromaIndexManager # Import the manager
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
|
||||
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ import logging
|
||||
import numpy as np # Make sure numpy is imported
|
||||
from typing import List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from services.base_service import BaseService
|
||||
from services.database import Chunk, SessionLocal
|
||||
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from services.chroma_manager import ChromaIndexManager
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
|
||||
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user