feat: add functions

This commit is contained in:
WangCham
2025-07-03 23:28:47 +08:00
committed by Junyan Qin
parent c4671fbf1c
commit 34fe8b324d
12 changed files with 75 additions and 49 deletions

View File

@@ -1,20 +1,20 @@
# 封装异步操作
import asyncio
import logging
from services.database import SessionLocal # 导入 SessionLocal 工厂函数
from pkg.rag.knowledge.services.database import SessionLocal
class BaseService:
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数
self.db_session_factory = SessionLocal
async def _run_sync(self, func, *args, **kwargs):
"""
在单独的线程中运行同步函数。
如果第一个参数是 session则在 to_thread 中获取新的 session。
"""
# 如果函数需要数据库会话作为第一个参数,我们在这里获取它
if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头
if getattr(func, '__name__', '').startswith('_db_'):
session = await asyncio.to_thread(self.db_session_factory)
try:
result = await asyncio.to_thread(func, session, *args, **kwargs)

View File

@@ -1,4 +1,4 @@
# services/chroma_manager.py
import numpy as np
import logging
from chromadb import PersistentClient
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
class ChromaIndexManager:
def __init__(self, collection_name: str = "default_collection"):
self.logger = logging.getLogger(self.__class__.__name__)
chroma_data_path = "./chroma_data"
chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma"))
os.makedirs(chroma_data_path, exist_ok=True)
self.client = PersistentClient(path=chroma_data_path)
self._collection_name = collection_name

View File

@@ -1,7 +1,7 @@
# services/chunker.py
import logging
from typing import List
from services.base_service import BaseService # Assuming BaseService provides _run_sync
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
logger = logging.getLogger(__name__)

View File

@@ -4,10 +4,10 @@ import logging
import numpy as np
from typing import List
from sqlalchemy.orm import Session
from services.base_service import BaseService
from services.database import Chunk, SessionLocal
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from services.chroma_manager import ChromaIndexManager # Import the manager
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager
logger = logging.getLogger(__name__)

View File

@@ -4,10 +4,10 @@ import logging
import numpy as np # Make sure numpy is imported
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from services.base_service import BaseService
from services.database import Chunk, SessionLocal
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from services.chroma_manager import ChromaIndexManager
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
logger = logging.getLogger(__name__)