mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-09 07:16:04 +00:00
feat: add functions
This commit is contained in:
@@ -1,18 +1,24 @@
|
||||
# RAG_Manager class (main class, adjust imports as needed)
|
||||
from __future__ import annotations # For type hinting in Python 3.7+
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from services.parser import FileParser
|
||||
from services.chunker import Chunker
|
||||
from services.embedder import Embedder
|
||||
from services.retriever import Retriever
|
||||
from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
|
||||
from services.embedding_models import EmbeddingModelFactory
|
||||
from services.chroma_manager import ChromaIndexManager
|
||||
from ...core import app
|
||||
from pkg.rag.knowledge.services.parser import FileParser
|
||||
from pkg.rag.knowledge.services.chunker import Chunker
|
||||
from pkg.rag.knowledge.services.embedder import Embedder
|
||||
from pkg.rag.knowledge.services.retriever import Retriever
|
||||
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
|
||||
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
from pkg.core import app # Adjust the import path as needed
|
||||
|
||||
|
||||
class RAG_Manager:
|
||||
def __init__(self, logger: logging.Logger = None):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application,logger: logging.Logger = None):
|
||||
self.ap = ap
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.embedding_model_type = None
|
||||
self.embedding_model_name = None
|
||||
@@ -21,11 +27,11 @@ class RAG_Manager:
|
||||
self.chunker = None
|
||||
self.embedder = None
|
||||
self.retriever = None
|
||||
|
||||
async def initialize_system(self):
|
||||
|
||||
async def initialize_rag_system(self):
|
||||
await asyncio.to_thread(create_db_and_tables)
|
||||
|
||||
async def create_model(self, embedding_model_type: str,
|
||||
async def create_specific_model(self, embedding_model_type: str,
|
||||
embedding_model_name: str):
|
||||
self.embedding_model_type = embedding_model_type
|
||||
self.embedding_model_name = embedding_model_name
|
||||
@@ -57,7 +63,7 @@ class RAG_Manager:
|
||||
)
|
||||
|
||||
|
||||
async def create_knowledge_base(self, kb_name: str, kb_description: str):
|
||||
async def create_knowledge_base(self, kb_name: str, kb_description: str ,):
|
||||
"""
|
||||
Creates a new knowledge base with the given name and description.
|
||||
If a knowledge base with the same name already exists, it returns that one.
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# 封装异步操作
|
||||
import asyncio
|
||||
import logging
|
||||
from services.database import SessionLocal # 导入 SessionLocal 工厂函数
|
||||
from pkg.rag.knowledge.services.database import SessionLocal
|
||||
|
||||
class BaseService:
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数
|
||||
self.db_session_factory = SessionLocal
|
||||
|
||||
async def _run_sync(self, func, *args, **kwargs):
|
||||
"""
|
||||
在单独的线程中运行同步函数。
|
||||
如果第一个参数是 session,则在 to_thread 中获取新的 session。
|
||||
"""
|
||||
# 如果函数需要数据库会话作为第一个参数,我们在这里获取它
|
||||
if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头
|
||||
|
||||
if getattr(func, '__name__', '').startswith('_db_'):
|
||||
session = await asyncio.to_thread(self.db_session_factory)
|
||||
try:
|
||||
result = await asyncio.to_thread(func, session, *args, **kwargs)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# services/chroma_manager.py
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
from chromadb import PersistentClient
|
||||
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
|
||||
class ChromaIndexManager:
|
||||
def __init__(self, collection_name: str = "default_collection"):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
chroma_data_path = "./chroma_data"
|
||||
chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma"))
|
||||
os.makedirs(chroma_data_path, exist_ok=True)
|
||||
self.client = PersistentClient(path=chroma_data_path)
|
||||
self._collection_name = collection_name
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# services/chunker.py
|
||||
import logging
|
||||
from typing import List
|
||||
from services.base_service import BaseService # Assuming BaseService provides _run_sync
|
||||
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ import logging
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from sqlalchemy.orm import Session
|
||||
from services.base_service import BaseService
|
||||
from services.database import Chunk, SessionLocal
|
||||
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from services.chroma_manager import ChromaIndexManager # Import the manager
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
|
||||
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ import logging
|
||||
import numpy as np # Make sure numpy is imported
|
||||
from typing import List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from services.base_service import BaseService
|
||||
from services.database import Chunk, SessionLocal
|
||||
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from services.chroma_manager import ChromaIndexManager
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
|
||||
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user