From c81d5a1a49194e2e806801f53c32b405f12627e7 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Wed, 21 May 2025 12:42:39 +0800 Subject: [PATCH 1/7] feat: add embeddings model management (#1461) * feat: add embeddings model management backend support Co-Authored-By: Junyan Qin * feat: add embeddings model management frontend support Co-Authored-By: Junyan Qin * chore: revert HttpClient URL to production setting Co-Authored-By: Junyan Qin * refactor: integrate embeddings models into models page with tabs Co-Authored-By: Junyan Qin * perf: move files * perf: remove `s` * feat: allow requester to declare supported types in manifest * feat(embedding): delete dimension and encoding format * feat: add extra_args for embedding moels * perf: i18n ref * fix: linter err * fix: lint err * fix: linter err --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Junyan Qin --- .../http/controller/groups/provider/models.py | 55 +- .../controller/groups/provider/requesters.py | 3 +- pkg/api/http/service/model.py | 89 ++- pkg/core/app.py | 4 +- pkg/core/stages/build_app.py | 7 +- pkg/entity/persistence/model.py | 21 + pkg/provider/modelmgr/entities.py | 2 +- pkg/provider/modelmgr/modelmgr.py | 89 ++- pkg/provider/modelmgr/requester.py | 51 +- .../modelmgr/requesters/anthropicmsgs.py | 2 +- .../modelmgr/requesters/anthropicmsgs.yaml | 2 + .../modelmgr/requesters/bailianchatcmpl.yaml | 2 + pkg/provider/modelmgr/requesters/chatcmpl.py | 38 +- .../modelmgr/requesters/chatcmpl.yaml | 3 + .../modelmgr/requesters/deepseekchatcmpl.yaml | 2 + .../modelmgr/requesters/geminichatcmpl.yaml | 2 + .../modelmgr/requesters/giteeaichatcmpl.yaml | 2 + .../modelmgr/requesters/lmstudiochatcmpl.yaml | 2 + .../modelmgr/requesters/modelscopechatcmpl.py | 2 +- .../requesters/modelscopechatcmpl.yaml | 2 + .../modelmgr/requesters/moonshotchatcmpl.yaml | 2 + .../modelmgr/requesters/ollamachat.py | 2 +- .../modelmgr/requesters/ollamachat.yaml | 2 + .../requesters/openrouterchatcmpl.yaml | 2 + .../modelmgr/requesters/ppiochatcmpl.yaml | 2 + .../requesters/siliconflowchatcmpl.yaml | 2 + .../modelmgr/requesters/volcarkchatcmpl.yaml | 2 + .../modelmgr/requesters/xaichatcmpl.yaml | 2 + .../modelmgr/requesters/zhipuaichatcmpl.yaml | 2 + .../home-sidebar/sidbarConfigList.tsx | 1 + .../{llm-form => }/ChooseRequesterEntity.ts | 0 .../models/component/ICreateEmbeddingField.ts | 7 + .../models/{ => component}/ICreateLLMField.ts | 0 .../embedding-card/EmbeddingCard.module.css | 97 +++ .../embedding-card/EmbeddingCard.tsx | 53 ++ .../embedding-card/EmbeddingCardVO.ts | 23 + .../embedding-form/EmbeddingForm.tsx | 563 ++++++++++++++++++ .../models/component/llm-form/LLMForm.tsx | 8 +- web/src/app/home/models/page.tsx | 183 +++++- web/src/app/infra/entities/api/index.ts | 23 + web/src/app/infra/http/HttpClient.ts | 42 +- web/src/i18n/locales/en-US.ts | 18 +- web/src/i18n/locales/zh-Hans.ts | 18 +- 43 files changed, 1370 insertions(+), 64 deletions(-) rename web/src/app/home/models/component/{llm-form => }/ChooseRequesterEntity.ts (100%) create mode 100644 web/src/app/home/models/component/ICreateEmbeddingField.ts rename web/src/app/home/models/{ => component}/ICreateLLMField.ts (100%) create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts create mode 100644 web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx diff --git a/pkg/api/http/controller/groups/provider/models.py b/pkg/api/http/controller/groups/provider/models.py index bb77986c..0de0c922 100644 --- a/pkg/api/http/controller/groups/provider/models.py +++ b/pkg/api/http/controller/groups/provider/models.py @@ -9,18 +9,18 @@ class LLMModelsRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={'models': await self.ap.model_service.get_llm_models()}) + return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()}) elif quart.request.method == 'POST': json_data = await quart.request.json - model_uuid = await self.ap.model_service.create_llm_model(json_data) + model_uuid = await self.ap.llm_model_service.create_llm_model(json_data) return self.success(data={'uuid': model_uuid}) @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(model_uuid: str) -> str: if quart.request.method == 'GET': - model = await self.ap.model_service.get_llm_model(model_uuid) + model = await self.ap.llm_model_service.get_llm_model(model_uuid) if model is None: return self.http_status(404, -1, 'model not found') @@ -29,11 +29,11 @@ class LLMModelsRouterGroup(group.RouterGroup): elif quart.request.method == 'PUT': json_data = await quart.request.json - await self.ap.model_service.update_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.update_llm_model(model_uuid, json_data) return self.success() elif quart.request.method == 'DELETE': - await self.ap.model_service.delete_llm_model(model_uuid) + await self.ap.llm_model_service.delete_llm_model(model_uuid) return self.success() @@ -41,6 +41,49 @@ class LLMModelsRouterGroup(group.RouterGroup): async def _(model_uuid: str) -> str: json_data = await quart.request.json - await self.ap.model_service.test_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.test_llm_model(model_uuid, json_data) + + return self.success() + + +@group.group_class('models/embedding', '/api/v1/provider/models/embedding') +class EmbeddingModelsRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['GET', 'POST']) + async def _() -> str: + if quart.request.method == 'GET': + return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + + model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data) + + return self.success(data={'uuid': model_uuid}) + + @self.route('/', methods=['GET', 'PUT', 'DELETE']) + async def _(model_uuid: str) -> str: + if quart.request.method == 'GET': + model = await self.ap.embedding_models_service.get_embedding_model(model_uuid) + + if model is None: + return self.http_status(404, -1, 'model not found') + + return self.success(data={'model': model}) + elif quart.request.method == 'PUT': + json_data = await quart.request.json + + await self.ap.embedding_models_service.update_embedding_model(model_uuid, json_data) + + return self.success() + elif quart.request.method == 'DELETE': + await self.ap.embedding_models_service.delete_embedding_model(model_uuid) + + return self.success() + + @self.route('//test', methods=['POST']) + async def _(model_uuid: str) -> str: + json_data = await quart.request.json + + await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data) return self.success() diff --git a/pkg/api/http/controller/groups/provider/requesters.py b/pkg/api/http/controller/groups/provider/requesters.py index 0f999288..af9e1540 100644 --- a/pkg/api/http/controller/groups/provider/requesters.py +++ b/pkg/api/http/controller/groups/provider/requesters.py @@ -8,7 +8,8 @@ class RequestersRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> quart.Response: - return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()}) + model_type = quart.request.args.get('type', '') + return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info(model_type)}) @self.route('/', methods=['GET']) async def _(requester_name: str) -> quart.Response: diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 74fb4e02..afeae3eb 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -10,7 +10,7 @@ from ....provider.modelmgr import requester as model_requester from ....provider import entities as llm_entities -class ModelsService: +class LLMModelsService: ap: app.Application def __init__(self, ap: app.Application) -> None: @@ -103,3 +103,90 @@ class ModelsService: funcs=[], extra_args={}, ) + + +class EmbeddingModelsService: + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_embedding_models(self) -> list[dict]: + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + + models = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models] + + async def create_embedding_model(self, model_data: dict) -> str: + model_data['uuid'] = str(uuid.uuid4()) + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) + ) + + embedding_model = await self.get_embedding_model(model_data['uuid']) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + return model_data['uuid'] + + async def get_embedding_model(self, model_uuid: str) -> dict | None: + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + model = result.first() + + if model is None: + return None + + return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) + + async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None: + if 'uuid' in model_data: + del model_data['uuid'] + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.EmbeddingModel) + .where(persistence_model.EmbeddingModel.uuid == model_uuid) + .values(**model_data) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + embedding_model = await self.get_embedding_model(model_uuid) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + async def delete_embedding_model(self, model_uuid: str) -> None: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None: + runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None + + if model_uuid != '_': + for model in self.ap.model_mgr.embedding_models: + if model.model_entity.uuid == model_uuid: + runtime_embedding_model = model + break + + if runtime_embedding_model is None: + raise Exception('model not found') + + else: + runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data) + + await runtime_embedding_model.requester.invoke_embedding( + query=None, + model=runtime_embedding_model, + input_text='Hello, world!', + extra_args={}, + ) diff --git a/pkg/core/app.py b/pkg/core/app.py index 911acd3d..318cddcb 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -103,7 +103,9 @@ class Application: user_service: user_service.UserService = None - model_service: model_service.ModelsService = None + llm_model_service: model_service.LLMModelsService = None + + embedding_models_service: model_service.EmbeddingModelsService = None pipeline_service: pipeline_service.PipelineService = None diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 6ee35610..482a468b 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -95,8 +95,11 @@ class BuildAppStage(stage.BootingStage): user_service_inst = user_service.UserService(ap) ap.user_service = user_service_inst - model_service_inst = model_service.ModelsService(ap) - ap.model_service = model_service_inst + llm_model_service_inst = model_service.LLMModelsService(ap) + ap.llm_model_service = llm_model_service_inst + + embedding_models_service_inst = model_service.EmbeddingModelsService(ap) + ap.embedding_models_service = embedding_models_service_inst pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst diff --git a/pkg/entity/persistence/model.py b/pkg/entity/persistence/model.py index 9eb2ccef..418cab70 100644 --- a/pkg/entity/persistence/model.py +++ b/pkg/entity/persistence/model.py @@ -23,3 +23,24 @@ class LLMModel(Base): server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now(), ) + + +class EmbeddingModel(Base): + """Embedding 模型""" + + __tablename__ = 'embedding_models' + + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py index cf856894..7bc02a32 100644 --- a/pkg/provider/modelmgr/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel): token_mgr: token.TokenManager - requester: requester.LLMAPIRequester + requester: requester.ProviderAPIRequester tool_call_supported: typing.Optional[bool] = False diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index b15e53a9..2c92eacc 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -18,7 +18,7 @@ class ModelManager: model_list: list[entities.LLMModelInfo] # deprecated - requesters: dict[str, requester.LLMAPIRequester] # deprecated + requesters: dict[str, requester.ProviderAPIRequester] # deprecated token_mgrs: dict[str, token.TokenManager] # deprecated @@ -28,9 +28,11 @@ class ModelManager: llm_models: list[requester.RuntimeLLMModel] + embedding_models: list[requester.RuntimeEmbeddingModel] + requester_components: list[engine.Component] - requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache + requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache def __init__(self, ap: app.Application): self.ap = ap @@ -38,6 +40,7 @@ class ModelManager: self.requesters = {} self.token_mgrs = {} self.llm_models = [] + self.embedding_models = [] self.requester_components = [] self.requester_dict = {} @@ -45,7 +48,7 @@ class ModelManager: self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') # forge requester class dict - requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} + requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {} for component in self.requester_components: requester_dict[component.metadata.name] = component.get_python_component_class() @@ -58,13 +61,11 @@ class ModelManager: self.ap.logger.info('Loading models from db...') self.llm_models = [] + self.embedding_models = [] # llm models result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) - llm_models = result.all() - - # load models for llm_model in llm_models: try: await self.load_llm_model(llm_model) @@ -73,11 +74,17 @@ class ModelManager: except Exception as e: self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}') + # embedding models + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + embedding_models = result.all() + for embedding_model in embedding_models: + await self.load_embedding_model(embedding_model) + async def init_runtime_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """初始化运行时模型""" + """初始化运行时 LLM 模型""" if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.LLMModel(**model_info._mapping) elif isinstance(model_info, dict): @@ -101,14 +108,47 @@ class ModelManager: return runtime_llm_model + async def init_runtime_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """初始化运行时 Embedding 模型""" + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.EmbeddingModel(**model_info._mapping) + elif isinstance(model_info, dict): + model_info = persistence_model.EmbeddingModel(**model_info) + + requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) + + await requester_inst.initialize() + + runtime_embedding_model = requester.RuntimeEmbeddingModel( + model_entity=model_info, + token_mgr=token.TokenManager( + name=model_info.uuid, + tokens=model_info.api_keys, + ), + requester=requester_inst, + ) + + return runtime_embedding_model + async def load_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """加载模型""" + """加载 LLM 模型""" runtime_llm_model = await self.init_runtime_llm_model(model_info) self.llm_models.append(runtime_llm_model) + async def load_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """加载 Embedding 模型""" + runtime_embedding_model = await self.init_runtime_embedding_model(model_info) + self.embedding_models.append(runtime_embedding_model) + async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated """通过名称获取模型""" for model in self.model_list: @@ -116,23 +156,44 @@ class ModelManager: return model raise ValueError(f'无法确定模型 {name} 的信息') - async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo: - """通过uuid获取模型""" + async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel: + """通过uuid获取 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == uuid: return model - raise ValueError(f'model {uuid} not found') + raise ValueError(f'LLM model {uuid} not found') + + async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel: + """通过uuid获取 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == uuid: + return model + raise ValueError(f'Embedding model {uuid} not found') async def remove_llm_model(self, model_uuid: str): - """移除模型""" + """移除 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == model_uuid: self.llm_models.remove(model) return - def get_available_requesters_info(self) -> list[dict]: + async def remove_embedding_model(self, model_uuid: str): + """移除 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == model_uuid: + self.embedding_models.remove(model) + return + + def get_available_requesters_info(self, model_type: str) -> list[dict]: """获取所有可用的请求器""" - return [component.to_plain_dict() for component in self.requester_components] + if model_type != '': + return [ + component.to_plain_dict() + for component in self.requester_components + if model_type in component.spec['support_type'] + ] + else: + return [component.to_plain_dict() for component in self.requester_components] def get_available_requester_info_by_name(self, name: str) -> dict | None: """通过名称获取请求器信息""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 244f4c82..9742a52c 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -20,22 +20,45 @@ class RuntimeLLMModel: token_mgr: token.TokenManager """api key管理器""" - requester: LLMAPIRequester + requester: ProviderAPIRequester """请求器实例""" def __init__( self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, - requester: LLMAPIRequester, + requester: ProviderAPIRequester, ): self.model_entity = model_entity self.token_mgr = token_mgr self.requester = requester -class LLMAPIRequester(metaclass=abc.ABCMeta): - """LLM API请求器""" +class RuntimeEmbeddingModel: + """运行时 Embedding 模型""" + + model_entity: persistence_model.EmbeddingModel + """模型数据""" + + token_mgr: token.TokenManager + """api key管理器""" + + requester: ProviderAPIRequester + """请求器实例""" + + def __init__( + self, + model_entity: persistence_model.EmbeddingModel, + token_mgr: token.TokenManager, + requester: ProviderAPIRequester, + ): + self.model_entity = model_entity + self.token_mgr = token_mgr + self.requester = requester + + +class ProviderAPIRequester(metaclass=abc.ABCMeta): + """Provider API请求器""" name: str = None @@ -74,3 +97,23 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): llm_entities.Message: 返回消息对象 """ pass + + async def invoke_embedding( + self, + query: core_entities.Query, + model: RuntimeEmbeddingModel, + input_text: str, + extra_args: dict[str, typing.Any] = {}, + ) -> list[float]: + """调用 Embedding API + + Args: + query (core_entities.Query): 请求上下文 + model (RuntimeEmbeddingModel): 使用的模型信息 + input_text (str): 输入文本 + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + + Returns: + list[float]: 返回的 embedding 向量 + """ + pass diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 38573854..b195ae51 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -15,7 +15,7 @@ from ...tools import entities as tools_entities from ....utils import image -class AnthropicMessages(requester.LLMAPIRequester): +class AnthropicMessages(requester.ProviderAPIRequester): """Anthropic Messages API 请求器""" client: anthropic.AsyncAnthropic diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index c124fed9..7dbcf3ed 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./anthropicmsgs.py diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index 24beb915..10aae30f 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./bailianchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 513086e5..98d1f13a 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -13,7 +13,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class OpenAIChatCompletions(requester.LLMAPIRequester): +class OpenAIChatCompletions(requester.ProviderAPIRequester): """OpenAI ChatCompletion API 请求器""" client: openai.AsyncClient @@ -141,3 +141,39 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def invoke_embedding( + self, + query: core_entities.Query, + model: requester.RuntimeEmbeddingModel, + input_text: str, + extra_args: dict[str, typing.Any] = {}, + ) -> list[float]: + """调用 Embedding API""" + self.client.api_key = model.token_mgr.get_token() + + args = { + 'model': model.model_entity.name, + 'input': input_text, + } + + if model.model_entity.extra_args: + args.update(model.model_entity.extra_args) + + args.update(extra_args) + + try: + resp = await self.client.embeddings.create(**args) + return resp.data[0].embedding + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + raise errors.RequesterError(f'请求参数错误: {e.message}') + except openai.AuthenticationError as e: + raise errors.RequesterError(f'无效的 api-key: {e.message}') + except openai.NotFoundError as e: + raise errors.RequesterError(f'请求路径错误: {e.message}') + except openai.RateLimitError as e: + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') + except openai.APIError as e: + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/pkg/provider/modelmgr/requesters/chatcmpl.yaml index 908b30ac..ff0de6f9 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -22,6 +22,9 @@ spec: type: integer required: true default: 120 + support_type: + - llm + - text-embedding execution: python: path: ./chatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml index ea2c7eea..6f320e66 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./deepseekchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml index 6bfc085e..73fca19c 100644 --- a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./geminichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index a18675a1..3a79bb49 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./giteeaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml index 893235b2..fbe57dad 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./lmstudiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index b8868f4d..4708f671 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -14,7 +14,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class ModelScopeChatCompletions(requester.LLMAPIRequester): +class ModelScopeChatCompletions(requester.ProviderAPIRequester): """ModelScope ChatCompletion API 请求器""" client: openai.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml index a641a672..a926d889 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./modelscopechatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml index f3ae73c8..52f7bcda 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./moonshotchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 2ea4bb7d..1456515f 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -17,7 +17,7 @@ from ....core import entities as core_entities REQUESTER_NAME: str = 'ollama-chat' -class OllamaChatCompletions(requester.LLMAPIRequester): +class OllamaChatCompletions(requester.ProviderAPIRequester): """Ollama平台 ChatCompletion API请求器""" client: ollama.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/ollamachat.yaml b/pkg/provider/modelmgr/requesters/ollamachat.yaml index 01435775..f4c4bf5a 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.yaml +++ b/pkg/provider/modelmgr/requesters/ollamachat.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./ollamachat.py diff --git a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml index 2ecee6cc..ea35bce6 100644 --- a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./openrouterchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml index 9f201aa9..a5a3421c 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./ppiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 19b3dcc3..3872cb6f 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./siliconflowchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml index 402f04e7..c711ef2d 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./volcarkchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml index 29db4eb3..2769a402 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./xaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml index a05184ef..34539d95 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./zhipuaichatcmpl.py diff --git a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx index e21317d6..ef9c6f45 100644 --- a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx +++ b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx @@ -47,6 +47,7 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/deploy/models/readme.html', }, }), + new SidebarChildVO({ id: 'pipelines', name: t('pipelines.title'), diff --git a/web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts b/web/src/app/home/models/component/ChooseRequesterEntity.ts similarity index 100% rename from web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts rename to web/src/app/home/models/component/ChooseRequesterEntity.ts diff --git a/web/src/app/home/models/component/ICreateEmbeddingField.ts b/web/src/app/home/models/component/ICreateEmbeddingField.ts new file mode 100644 index 00000000..ea198f3f --- /dev/null +++ b/web/src/app/home/models/component/ICreateEmbeddingField.ts @@ -0,0 +1,7 @@ +export interface ICreateEmbeddingField { + name: string; + model_provider: string; + url: string; + api_key: string; + extra_args?: string[]; +} diff --git a/web/src/app/home/models/ICreateLLMField.ts b/web/src/app/home/models/component/ICreateLLMField.ts similarity index 100% rename from web/src/app/home/models/ICreateLLMField.ts rename to web/src/app/home/models/component/ICreateLLMField.ts diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css new file mode 100644 index 00000000..9c6c54f7 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css @@ -0,0 +1,97 @@ +.cardContainer { + width: 100%; + height: 10rem; + background-color: #fff; + border-radius: 10px; + box-shadow: 0px 2px 2px 0 rgba(0, 0, 0, 0.2); + padding: 1.2rem; + cursor: pointer; +} + +.cardContainer:hover { + box-shadow: 0px 2px 8px 0 rgba(0, 0, 0, 0.1); +} + +.iconBasicInfoContainer { + width: 100%; + height: 100%; + display: flex; + flex-direction: row; + gap: 0.8rem; + user-select: none; +} + +.iconImage { + width: 3.8rem; + height: 3.8rem; + margin: 0.2rem; + border-radius: 50%; +} + +.basicInfoContainer { + display: flex; + flex-direction: column; + gap: 0.2rem; + min-width: 0; + width: 100%; +} + +.basicInfoText { + font-size: 1.4rem; + font-weight: bold; +} + +.providerContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; +} + +.providerIcon { + width: 1.2rem; + height: 1.2rem; + margin-top: 0.2rem; + color: #626262; +} + +.providerLabel { + font-size: 1.2rem; + font-weight: 600; + color: #626262; +} + +.baseURLContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; + width: calc(100% - 3rem); +} + +.baseURLIcon { + width: 1.2rem; + height: 1.2rem; + color: #626262; +} + +.baseURLText { + font-size: 1rem; + width: 100%; + color: #626262; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + max-width: 100%; +} + +.bigText { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + font-size: 1.4rem; + font-weight: bold; + max-width: 100%; +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx new file mode 100644 index 00000000..e3dfaf80 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx @@ -0,0 +1,53 @@ +import styles from './EmbeddingCard.module.css'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; + +export default function EmbeddingCard({ cardVO }: { cardVO: EmbeddingCardVO }) { + return ( +
+
+ icon + +
+ {/* 名称 */} +
+ {cardVO.name} +
+ {/* 厂商 */} +
+ + + + + {cardVO.providerLabel} + +
+ {/* baseURL */} +
+ + + + {cardVO.baseURL} +
+
+
+
+ ); +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts new file mode 100644 index 00000000..f6d960f6 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts @@ -0,0 +1,23 @@ +export interface IEmbeddingCardVO { + id: string; + iconURL: string; + name: string; + providerLabel: string; + baseURL: string; +} + +export class EmbeddingCardVO implements IEmbeddingCardVO { + id: string; + iconURL: string; + providerLabel: string; + name: string; + baseURL: string; + + constructor(props: IEmbeddingCardVO) { + this.id = props.id; + this.iconURL = props.iconURL; + this.providerLabel = props.providerLabel; + this.name = props.name; + this.baseURL = props.baseURL; + } +} diff --git a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx new file mode 100644 index 00000000..4658a22f --- /dev/null +++ b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx @@ -0,0 +1,563 @@ +import { ICreateEmbeddingField } from '@/app/home/models/component/ICreateEmbeddingField'; +import { useEffect, useState } from 'react'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { EmbeddingModel } from '@/app/infra/entities/api'; +import { UUID } from 'uuidjs'; + +import { zodResolver } from '@hookform/resolvers/zod'; +import { useForm } from 'react-hook-form'; +import { z } from 'zod'; +import { useTranslation } from 'react-i18next'; + +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogFooter, +} from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { toast } from 'sonner'; +import { i18nObj } from '@/i18n/I18nProvider'; + +const getExtraArgSchema = (t: (key: string) => string) => + z + .object({ + key: z.string().min(1, { message: t('models.keyNameRequired') }), + type: z.enum(['string', 'number', 'boolean']), + value: z.string(), + }) + .superRefine((data, ctx) => { + if (data.type === 'number' && isNaN(Number(data.value))) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('models.mustBeValidNumber'), + path: ['value'], + }); + } + if ( + data.type === 'boolean' && + data.value !== 'true' && + data.value !== 'false' + ) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('models.mustBeTrueOrFalse'), + path: ['value'], + }); + } + }); + +const getFormSchema = (t: (key: string) => string) => + z.object({ + name: z.string().min(1, { message: t('models.modelNameRequired') }), + model_provider: z + .string() + .min(1, { message: t('models.modelProviderRequired') }), + url: z.string().min(1, { message: t('models.requestURLRequired') }), + api_key: z.string().min(1, { message: t('models.apiKeyRequired') }), + extra_args: z.array(getExtraArgSchema(t)).optional(), + }); + +export default function EmbeddingForm({ + editMode, + initEmbeddingId, + onFormSubmit, + onFormCancel, + onEmbeddingDeleted, +}: { + editMode: boolean; + initEmbeddingId?: string; + onFormSubmit: () => void; + onFormCancel: () => void; + onEmbeddingDeleted: () => void; +}) { + const { t } = useTranslation(); + const formSchema = getFormSchema(t); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + name: '', + model_provider: '', + url: '', + api_key: 'sk-xxxxx', + extra_args: [], + }, + }); + + const [extraArgs, setExtraArgs] = useState< + { key: string; type: 'string' | 'number' | 'boolean'; value: string }[] + >([]); + + const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false); + const [requesterNameList, setRequesterNameList] = useState< + IChooseRequesterEntity[] + >([]); + const [requesterDefaultURLList, setRequesterDefaultURLList] = useState< + string[] + >([]); + const [modelTesting, setModelTesting] = useState(false); + + useEffect(() => { + initEmbeddingModelFormComponent().then(() => { + if (editMode && initEmbeddingId) { + getEmbeddingConfig(initEmbeddingId).then((val) => { + form.setValue('name', val.name); + form.setValue('model_provider', val.model_provider); + // setCurrentModelProvider(val.model_provider); + form.setValue('url', val.url); + form.setValue('api_key', val.api_key); + if (val.extra_args) { + const args = val.extra_args.map((arg) => { + const [key, value] = arg.split(':'); + let type: 'string' | 'number' | 'boolean' = 'string'; + if (!isNaN(Number(value))) { + type = 'number'; + } else if (value === 'true' || value === 'false') { + type = 'boolean'; + } + return { + key, + type, + value, + }; + }); + setExtraArgs(args); + form.setValue('extra_args', args); + } + }); + } else { + form.reset(); + } + }); + }, []); + + const addExtraArg = () => { + setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]); + }; + + const updateExtraArg = ( + index: number, + field: 'key' | 'type' | 'value', + value: string, + ) => { + const newArgs = [...extraArgs]; + newArgs[index] = { + ...newArgs[index], + [field]: value, + }; + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); + }; + + const removeExtraArg = (index: number) => { + const newArgs = extraArgs.filter((_, i) => i !== index); + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); + }; + + async function initEmbeddingModelFormComponent() { + const requesterNameList = + await httpClient.getProviderRequesters('text-embedding'); + setRequesterNameList( + requesterNameList.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }), + ); + setRequesterDefaultURLList( + requesterNameList.requesters.map((item) => { + const config = item.spec.config; + for (let i = 0; i < config.length; i++) { + if (config[i].name == 'base_url') { + return config[i].default?.toString() || ''; + } + } + return ''; + }), + ); + } + + async function getEmbeddingConfig( + id: string, + ): Promise { + const embeddingModel = await httpClient.getProviderEmbeddingModel(id); + + const fakeExtraArgs = []; + const extraArgs = embeddingModel.model.extra_args as Record; + for (const key in extraArgs) { + fakeExtraArgs.push(`${key}:${extraArgs[key]}`); + } + return { + name: embeddingModel.model.name, + model_provider: embeddingModel.model.requester, + url: embeddingModel.model.requester_config?.base_url, + api_key: embeddingModel.model.api_keys[0], + extra_args: fakeExtraArgs, + }; + } + + function handleFormSubmit(value: z.infer) { + const extraArgsObj: Record = {}; + value.extra_args?.forEach( + (arg: { key: string; type: string; value: string }) => { + if (arg.type === 'number') { + extraArgsObj[arg.key] = Number(arg.value); + } else if (arg.type === 'boolean') { + extraArgsObj[arg.key] = arg.value === 'true'; + } else { + extraArgsObj[arg.key] = arg.value; + } + }, + ); + + const embeddingModel: EmbeddingModel = { + uuid: editMode ? initEmbeddingId || '' : UUID.generate(), + name: value.name, + description: '', + requester: value.model_provider, + requester_config: { + base_url: value.url, + timeout: 120, + }, + extra_args: extraArgsObj, + api_keys: [value.api_key], + }; + + if (editMode) { + onSaveEdit(embeddingModel).then(() => { + form.reset(); + }); + } else { + onCreateEmbedding(embeddingModel).then(() => { + form.reset(); + }); + } + } + + async function onCreateEmbedding(embeddingModel: EmbeddingModel) { + try { + await httpClient.createProviderEmbeddingModel(embeddingModel); + onFormSubmit(); + toast.success(t('models.createSuccess')); + } catch (err) { + toast.error(t('models.createError') + (err as Error).message); + } + } + + async function onSaveEdit(embeddingModel: EmbeddingModel) { + try { + await httpClient.updateProviderEmbeddingModel( + initEmbeddingId || '', + embeddingModel, + ); + onFormSubmit(); + toast.success(t('models.saveSuccess')); + } catch (err) { + toast.error(t('models.saveError') + (err as Error).message); + } + } + + function deleteModel() { + if (initEmbeddingId) { + httpClient + .deleteProviderEmbeddingModel(initEmbeddingId) + .then(() => { + onEmbeddingDeleted(); + toast.success(t('models.deleteSuccess')); + }) + .catch((err) => { + toast.error(t('models.deleteError') + err.message); + }); + } + } + + function testEmbeddingModelInForm() { + setModelTesting(true); + httpClient + .testEmbeddingModel('_', { + uuid: '', + name: form.getValues('name'), + description: '', + requester: form.getValues('model_provider'), + requester_config: { + base_url: form.getValues('url'), + timeout: 120, + }, + api_keys: [form.getValues('api_key')], + }) + .then((res) => { + console.log(res); + toast.success(t('models.testSuccess')); + }) + .catch(() => { + toast.error(t('models.testError')); + }) + .finally(() => { + setModelTesting(false); + }); + } + + return ( +
+ + + + {t('common.confirmDelete')} + + + {t('models.deleteConfirmation')} + + + + + + + + +
+ +
+ ( + + + {t('models.modelName')} + * + + + + + + + {t('models.modelProviderDescription')} + + + )} + /> + + ( + + + {t('models.modelProvider')} + * + + + + + + + )} + /> + + ( + + + {t('models.requestURL')} + * + + + + + + + )} + /> + + ( + + + {t('models.apiKey')} + * + + + + + + + )} + /> + + + {t('models.extraParameters')} +
+ {extraArgs.map((arg, index) => ( +
+ + updateExtraArg(index, 'key', e.target.value) + } + /> + + + updateExtraArg(index, 'value', e.target.value) + } + /> + +
+ ))} + +
+ + {t('embedding.extraParametersDescription')} + + +
+
+ + {editMode && ( + + )} + + + + + + + +
+ +
+ ); +} diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index f483f183..73cc32fe 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -1,6 +1,6 @@ -import { ICreateLLMField } from '@/app/home/models/ICreateLLMField'; +import { ICreateLLMField } from '@/app/home/models/component/ICreateLLMField'; import { useEffect, useState } from 'react'; -import { IChooseRequesterEntity } from '@/app/home/models/component/llm-form/ChooseRequesterEntity'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; import { UUID } from 'uuidjs'; @@ -197,7 +197,7 @@ export default function LLMForm({ }; async function initLLMModelFormComponent() { - const requesterNameList = await httpClient.getProviderRequesters(); + const requesterNameList = await httpClient.getProviderRequesters('llm'); setRequesterNameList( requesterNameList.requesters.map((item) => { return { @@ -596,7 +596,7 @@ export default function LLMForm({ - {t('models.extraParametersDescription')} + {t('llm.extraParametersDescription')} diff --git a/web/src/app/home/models/page.tsx b/web/src/app/home/models/page.tsx index 3ccec486..2f936753 100644 --- a/web/src/app/home/models/page.tsx +++ b/web/src/app/home/models/page.tsx @@ -8,6 +8,7 @@ import LLMForm from '@/app/home/models/component/llm-form/LLMForm'; import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; import { Dialog, DialogContent, @@ -17,6 +18,9 @@ import { import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; import { i18nObj } from '@/i18n/I18nProvider'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; +import EmbeddingCard from '@/app/home/models/component/embedding-card/EmbeddingCard'; +import EmbeddingForm from '@/app/home/models/component/embedding-form/EmbeddingForm'; export default function LLMConfigPage() { const { t } = useTranslation(); @@ -24,13 +28,21 @@ export default function LLMConfigPage() { const [modalOpen, setModalOpen] = useState(false); const [isEditForm, setIsEditForm] = useState(false); const [nowSelectedLLM, setNowSelectedLLM] = useState(null); + const [embeddingCardList, setEmbeddingCardList] = useState( + [], + ); + const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false); + const [isEditEmbeddingForm, setIsEditEmbeddingForm] = useState(false); + const [nowSelectedEmbedding, setNowSelectedEmbedding] = + useState(null); useEffect(() => { getLLMModelList(); + getEmbeddingModelList(); }, []); async function getLLMModelList() { - const requesterNameListResp = await httpClient.getProviderRequesters(); + const requesterNameListResp = await httpClient.getProviderRequesters('llm'); const requesterNameList = requesterNameListResp.requesters.map((item) => { return { label: i18nObj(item.label), @@ -74,6 +86,55 @@ export default function LLMConfigPage() { setNowSelectedLLM(null); setModalOpen(true); } + function selectEmbedding(cardVO: EmbeddingCardVO) { + setIsEditEmbeddingForm(true); + setNowSelectedEmbedding(cardVO); + setEmbeddingModalOpen(true); + } + + function handleCreateEmbeddingModelClick() { + setIsEditEmbeddingForm(false); + setNowSelectedEmbedding(null); + setEmbeddingModalOpen(true); + } + async function getEmbeddingModelList() { + const requesterNameListResp = + await httpClient.getProviderRequesters('text-embedding'); + const requesterNameList = requesterNameListResp.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }); + + httpClient + .getProviderEmbeddingModels() + .then((resp) => { + const embeddingModelList: EmbeddingCardVO[] = resp.models.map( + (model: { + uuid: string; + requester: string; + name: string; + requester_config?: { base_url?: string }; + }) => { + return new EmbeddingCardVO({ + id: model.uuid, + iconURL: httpClient.getProviderRequesterIconURL(model.requester), + name: model.name, + providerLabel: + requesterNameList.find((item) => item.value === model.requester) + ?.label || model.requester.substring(0, 10), + baseURL: model.requester_config?.base_url || '', + }); + }, + ); + setEmbeddingCardList(embeddingModelList); + }) + .catch((err) => { + console.error('get Embedding model list error', err); + toast.error(t('embedding.getModelListError') + err.message); + }); + } return (
@@ -101,26 +162,108 @@ export default function LLMConfigPage() { /> -
- - {cardList.map((cardVO) => { - return ( -
{ - selectLLM(cardVO); - }} - > - + + + + + {isEditEmbeddingForm + ? t('embedding.editModel') + : t('embedding.createModel')} + + + { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + onFormCancel={() => { + setEmbeddingModalOpen(false); + }} + onEmbeddingDeleted={() => { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + /> + + + + +
+
+ + + {t('llm.llmModels')} + + + {t('embedding.embeddingModels')} + + +
+ +
+

{t('llm.description')}

- ); - })} -
+ + +
+

+ {t('embedding.description')} +

+
+
+
+ + +
+ + {cardList.map((cardVO) => { + return ( +
{ + selectLLM(cardVO); + }} + > + +
+ ); + })} +
+
+ + +
+ + {embeddingCardList.map((cardVO) => { + return ( +
{ + selectEmbedding(cardVO); + }} + > + +
+ ); + })} +
+
+
); } diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index d86a8be0..53ddf1dd 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -55,6 +55,29 @@ export interface LLMModel { // updated_at: string; } +export interface ApiRespProviderEmbeddingModels { + models: EmbeddingModel[]; +} + +export interface ApiRespProviderEmbeddingModel { + model: EmbeddingModel; +} + +export interface EmbeddingModel { + name: string; + description: string; + uuid: string; + requester: string; + requester_config: { + base_url: string; + timeout: number; + }; + extra_args?: object; + api_keys: string[]; + // created_at: string; + // updated_at: string; +} + export interface ApiRespPipelines { pipelines: Pipeline[]; } diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 5193703b..1fd335d9 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -10,6 +10,9 @@ import { ApiRespProviderLLMModels, ApiRespProviderLLMModel, LLMModel, + ApiRespProviderEmbeddingModels, + ApiRespProviderEmbeddingModel, + EmbeddingModel, ApiRespPipelines, Pipeline, ApiRespPlatformAdapters, @@ -226,8 +229,10 @@ class HttpClient { // real api request implementation // ============ Provider API ============ - public getProviderRequesters(): Promise { - return this.get('/api/v1/provider/requesters'); + public getProviderRequesters( + model_type: string, + ): Promise { + return this.get('/api/v1/provider/requesters', { type: model_type }); } public getProviderRequester(name: string): Promise { @@ -275,6 +280,39 @@ class HttpClient { return this.post(`/api/v1/provider/models/llm/${uuid}/test`, model); } + // ============ Provider Model Embedding ============ + public getProviderEmbeddingModels(): Promise { + return this.get('/api/v1/provider/models/embedding'); + } + + public getProviderEmbeddingModel( + uuid: string, + ): Promise { + return this.get(`/api/v1/provider/models/embedding/${uuid}`); + } + + public createProviderEmbeddingModel(model: EmbeddingModel): Promise { + return this.post('/api/v1/provider/models/embedding', model); + } + + public deleteProviderEmbeddingModel(uuid: string): Promise { + return this.delete(`/api/v1/provider/models/embedding/${uuid}`); + } + + public updateProviderEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.put(`/api/v1/provider/models/embedding/${uuid}`, model); + } + + public testEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model); + } + // ============ Pipeline API ============ public getGeneralPipelineMetadata(): Promise { // as designed, this method will be deprecated, and only for developer to check the prefered config schema diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 1975a521..d0df9841 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -86,14 +86,13 @@ const enUS = { string: 'String', number: 'Number', boolean: 'Boolean', - extraParametersDescription: - 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', selectModelProvider: 'Select Model Provider', modelProviderDescription: 'Please fill in the model name provided by the supplier', selectModel: 'Select Model', testSuccess: 'Test successful', testError: 'Test failed, please check your model configuration', + llmModels: 'LLM Models', }, bots: { title: 'Bots', @@ -259,6 +258,21 @@ const enUS = { 'Password reset failed, please check your email and recovery key', backToLogin: 'Back to Login', }, + embedding: { + description: 'Manage Embedding models for text vectorization', + createModel: 'Create Embedding Model', + editModel: 'Edit Embedding Model', + getModelListError: 'Failed to get Embedding model list: ', + embeddingModels: 'Embedding', + extraParametersDescription: + 'Will be attached to the request body, such as encoding_format, dimensions, etc.', + }, + llm: { + description: 'Manage LLM models for conversation generation', + llmModels: 'LLM', + extraParametersDescription: + 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', + }, }; export default enUS; diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 2ded8236..96acc0e6 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -87,13 +87,12 @@ const zhHans = { string: '字符串', number: '数字', boolean: '布尔值', - extraParametersDescription: - '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', selectModelProvider: '选择模型供应商', modelProviderDescription: '请填写供应商向您提供的模型名称', selectModel: '请选择模型', testSuccess: '测试成功', testError: '测试失败,请检查模型配置', + llmModels: '对话模型', }, bots: { title: '机器人', @@ -251,6 +250,21 @@ const zhHans = { resetFailed: '密码重置失败,请检查邮箱和恢复密钥是否正确', backToLogin: '返回登录', }, + embedding: { + description: '管理嵌入模型,用于向量化文本', + createModel: '创建嵌入模型', + editModel: '编辑嵌入模型', + getModelListError: '获取嵌入模型列表失败:', + embeddingModels: '嵌入模型', + extraParametersDescription: + '将在请求时附加到请求体中,如 encoding_format, dimensions 等', + }, + llm: { + llmModels: '对话模型', + description: '管理 LLM 模型,用于对话消息生成', + extraParametersDescription: + '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', + }, }; export default zhHans; From 157ffdc34c96d3faf7bc6b9d22d362e55834c678 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Tue, 10 Jun 2025 08:34:53 +0800 Subject: [PATCH 2/7] feat: add knowledge page --- web/src/app/home/knowledge/page.tsx | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 web/src/app/home/knowledge/page.tsx diff --git a/web/src/app/home/knowledge/page.tsx b/web/src/app/home/knowledge/page.tsx new file mode 100644 index 00000000..9707a8ee --- /dev/null +++ b/web/src/app/home/knowledge/page.tsx @@ -0,0 +1,5 @@ +'use client'; + +export default function KnowledgePage() { + return
KnowledgePage
; +} From 348f6d9eaa0759638a37e752502f361c5e848f75 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Wed, 11 Jun 2025 20:24:42 +0800 Subject: [PATCH 3/7] feat: add api for uploading files --- pkg/api/http/controller/groups/files.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pkg/api/http/controller/groups/files.py b/pkg/api/http/controller/groups/files.py index 0a8b2210..d08cbd71 100644 --- a/pkg/api/http/controller/groups/files.py +++ b/pkg/api/http/controller/groups/files.py @@ -2,6 +2,10 @@ from __future__ import annotations import quart import mimetypes +import uuid +import asyncio + +import quart.datastructures from .. import group @@ -20,3 +24,22 @@ class FilesRouterGroup(group.RouterGroup): mime_type = 'image/jpeg' return quart.Response(image_bytes, mimetype=mime_type) + + @self.route('/documents', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> quart.Response: + request = quart.request + # get file bytes from 'file' + file = (await request.files)['file'] + assert isinstance(file, quart.datastructures.FileStorage) + + file_bytes = await asyncio.to_thread(file.stream.read) + extension = file.filename.split('.')[-1] + + file_key = str(uuid.uuid4()) + '.' + extension + # save file to storage + await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes) + return self.success( + data={ + 'file_id': file_key, + } + ) From 4bcc06c9559f08f8e84b3c95a94eb4258da22688 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Wed, 25 Jun 2025 14:32:53 +0800 Subject: [PATCH 4/7] kb --- .../http/controller/groups/knowledge_base.py | 83 +++++ pkg/core/app.py | 7 + pkg/rag/knowledge/RAG_Manager.py | 283 +++++++++++++++++ pkg/rag/knowledge/services/__init__.py | 0 pkg/rag/knowledge/services/base_service.py | 26 ++ pkg/rag/knowledge/services/chroma_manager.py | 65 ++++ pkg/rag/knowledge/services/chunker.py | 63 ++++ pkg/rag/knowledge/services/database.py | 57 ++++ pkg/rag/knowledge/services/embedder.py | 93 ++++++ .../knowledge/services/embedding_models.py | 223 ++++++++++++++ pkg/rag/knowledge/services/parser.py | 288 ++++++++++++++++++ pkg/rag/knowledge/services/retriever.py | 106 +++++++ pkg/rag/knowledge/utils/crawler.py | 0 pyproject.toml | 11 + 14 files changed, 1305 insertions(+) create mode 100644 pkg/api/http/controller/groups/knowledge_base.py create mode 100644 pkg/rag/knowledge/RAG_Manager.py create mode 100644 pkg/rag/knowledge/services/__init__.py create mode 100644 pkg/rag/knowledge/services/base_service.py create mode 100644 pkg/rag/knowledge/services/chroma_manager.py create mode 100644 pkg/rag/knowledge/services/chunker.py create mode 100644 pkg/rag/knowledge/services/database.py create mode 100644 pkg/rag/knowledge/services/embedder.py create mode 100644 pkg/rag/knowledge/services/embedding_models.py create mode 100644 pkg/rag/knowledge/services/parser.py create mode 100644 pkg/rag/knowledge/services/retriever.py create mode 100644 pkg/rag/knowledge/utils/crawler.py diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py new file mode 100644 index 00000000..c819397a --- /dev/null +++ b/pkg/api/http/controller/groups/knowledge_base.py @@ -0,0 +1,83 @@ +import quart +from __future__ import annotations +from .. import group + +@group.group_class('knowledge_base', '/api/v1/knowledge/bases') +class KnowledgeBaseRouterGroup(group.RouterGroup): + + # 定义成功方法 + def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response: + return quart.jsonify({ + "code": code, + "data": data or {}, + "msg": msg + }) + + + + async def initialize(self) -> None: + rag = self.ap.knowledge_base_service.RAG_Manager() + + @self.route('', methods=['POST', 'GET']) + async def _() -> str: + + if quart.request.method == 'GET': + knowledge_bases = await rag.get_all_knowledge_bases() + bases_list = [ + { + "uuid": kb.id, + "name": kb.name, + "description": kb.description, + } for kb in knowledge_bases + ] + return self.success(code=0, + data={'bases': bases_list}, + msg='ok') + + json_data = await quart.request.json + knowledge_base_uuid = await rag.create_knowledge_base( + json_data.get('name'), + json_data.get('description') + ) + return self.success() + + + @self.route('/', methods=['GET']) + async def _(knowledge_base_uuid: str) -> str: + if quart.request.method == 'GET': + knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid) + + if knowledge_base is None: + return self.http_status(404, -1, 'knowledge base not found') + + return self.success( + code=0, + data={ + "name": knowledge_base.name, + "description": knowledge_base.description, + "uuid": knowledge_base.id + }, + msg='ok' + ) + + @self.route('//files', methods=['GET']) + async def _(knowledge_base_uuid: str) -> str: + if quart.request.method == 'GET': + files = await rag.get_files_by_knowledge_base(knowledge_base_uuid) + return self.success(code=0,data=[{ + "id": file.id, + "file_name": file.file_name, + "status": file.status + } for file in files],msg='ok') + + # delete specific file in knowledge base + @self.route('//files/', methods=['DELETE']) + async def _(knowledge_base_uuid: str, file_id: str) -> str: + await rag.delete_data_by_file_id(file_id) + return self.success(code=0, msg='ok') + + # delete specific kb + @self.route('/', methods=['DELETE']) + async def _(knowledge_base_uuid: str) -> str: + await rag.delete_kb_by_id(knowledge_base_uuid) + return self.success(code=0, msg='ok') diff --git a/pkg/core/app.py b/pkg/core/app.py index 318cddcb..d8824466 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -27,6 +27,10 @@ from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities +from ...pkg.rag.knowledge import RAG_Manager + + + class Application: @@ -99,6 +103,8 @@ class Application: storage_mgr: storagemgr.StorageMgr = None + knowledge_base_service: RAG_Manager = None + # ========= HTTP Services ========= user_service: user_service.UserService = None @@ -111,6 +117,7 @@ class Application: bot_service: bot_service.BotService = None + def __init__(self): pass diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py new file mode 100644 index 00000000..e172c132 --- /dev/null +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -0,0 +1,283 @@ +# RAG_Manager class (main class, adjust imports as needed) +import logging +import os +import asyncio +from services.parser import FileParser +from services.chunker import Chunker +from services.embedder import Embedder +from services.retriever import Retriever +from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly +from services.embedding_models import EmbeddingModelFactory +from services.chroma_manager import ChromaIndexManager +from ...core import app + +class RAG_Manager: + def __init__(self, logger: logging.Logger = None): + self.logger = logger or logging.getLogger(__name__) + self.embedding_model_type = None + self.embedding_model_name = None + self.chroma_manager = None + self.parser = None + self.chunker = None + self.embedder = None + self.retriever = None + self.ap = app.Application + + async def initialize_system(self): + await asyncio.to_thread(create_db_and_tables) + + async def create_model(self, embedding_model_type: str, + embedding_model_name: str): + self.embedding_model_type = embedding_model_type + self.embedding_model_name = embedding_model_name + + try: + model = EmbeddingModelFactory.create_model( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name + ) + self.logger.info(f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}") + except Exception as e: + self.logger.critical(f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}") + raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.") + + self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}") + + self.parser = FileParser() + self.chunker = Chunker() + # Pass chroma_manager to Embedder and Retriever + self.embedder = Embedder( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name, + chroma_manager=self.chroma_manager # Inject dependency + ) + self.retriever = Retriever( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name, + chroma_manager=self.chroma_manager # Inject dependency + ) + + + async def create_knowledge_base(self, kb_name: str, kb_description: str): + """ + Creates a new knowledge base with the given name and description. + If a knowledge base with the same name already exists, it returns that one. + """ + try: + def _get_kb_sync(name): + session = SessionLocal() + try: + return session.query(KnowledgeBase).filter_by(name=name).first() + finally: + session.close() + + kb = await asyncio.to_thread(_get_kb_sync, kb_name) + + if not kb: + def _add_kb_sync(): + session = SessionLocal() + try: + new_kb = KnowledgeBase(name=kb_name, description=kb_description) + session.add(new_kb) + session.commit() + session.refresh(new_kb) + return new_kb + finally: + session.close() + kb = await asyncio.to_thread(_add_kb_sync) + except Exception as e: + self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) + raise + except Exception as e: + self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) + raise + + async def get_all_knowledge_bases(self): + """ + Retrieves all knowledge bases from the database. + """ + try: + def _get_all_kbs_sync(): + session = SessionLocal() + try: + return session.query(KnowledgeBase).all() + finally: + session.close() + + kbs = await asyncio.to_thread(_get_all_kbs_sync) + return kbs + except Exception as e: + self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True) + return [] + + async def get_knowledge_base_by_id(self, kb_id: int): + """ + Retrieves a knowledge base by its ID. + """ + try: + def _get_kb_sync(kb_id): + session = SessionLocal() + try: + return session.query(KnowledgeBase).filter_by(id=kb_id).first() + finally: + session.close() + + kb = await asyncio.to_thread(_get_kb_sync, kb_id) + return kb + except Exception as e: + self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True) + return None + + async def get_files_by_knowledge_base(self, kb_id: int): + try: + def _get_files_sync(kb_id): + session = SessionLocal() + try: + return session.query(File).filter_by(kb_id=kb_id).all() + finally: + session.close() + + files = await asyncio.to_thread(_get_files_sync, kb_id) + return files + except Exception as e: + self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True) + return [] + + + async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"): + self.logger.info(f"Starting data storage process for file: {file_path}") + try: + def _get_kb_sync(name): + session = SessionLocal() + try: + return session.query(KnowledgeBase).filter_by(name=name).first() + finally: + session.close() + + kb = await asyncio.to_thread(_get_kb_sync, kb_name) + + if not kb: + self.logger.info(f"Knowledge Base '{kb_name}' not found. Creating a new one.") + def _add_kb_sync(): + session = SessionLocal() + try: + new_kb = KnowledgeBase(name=kb_name, description=kb_description) + session.add(new_kb) + session.commit() + session.refresh(new_kb) + return new_kb + finally: + session.close() + kb = await asyncio.to_thread(_add_kb_sync) + self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})") + + def _add_file_sync(kb_id, file_name, path, file_type): + session = SessionLocal() + try: + file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type) + session.add(file) + session.commit() + session.refresh(file) + return file + finally: + session.close() + + file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type) + self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})") + + text = await self.parser.parse(file_path) + if not text: + self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.") + # You might want to delete the file_obj from the DB here if it's empty. + session = SessionLocal() + try: + session.delete(file_obj) + session.commit() + except Exception as del_e: + self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}") + finally: + session.close() + return + + chunks_texts = await self.chunker.chunk(text) + self.logger.info(f"Chunked into {len(chunks_texts)} pieces.") + + # embed_and_store now handles both DB chunk saving and Chroma embedding + await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) + + self.logger.info(f"Data storage process completed for file: {file_path}") + + except Exception as e: + self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True) + # Consider cleaning up partially stored data if an error occurs. + return + + async def retrieve_data(self, query: str): + self.logger.info(f"Starting data retrieval process for query: '{query}'") + try: + retrieved_chunks = await self.retriever.retrieve(query) + self.logger.info(f"Successfully retrieved {len(retrieved_chunks)} chunks for query.") + return retrieved_chunks + except Exception as e: + self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) + return [] + + async def delete_data_by_file_id(self, file_id: int): + """ + Deletes data associated with a specific file_id from both the relational DB and Chroma. + """ + self.logger.info(f"Starting data deletion process for file_id: {file_id}") + session = SessionLocal() + try: + # 1. Delete from Chroma + await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) + + # 2. Delete chunks from relational DB + chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() + for chunk in chunks_to_delete: + session.delete(chunk) + self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.") + + # 3. Delete file entry from relational DB + file_to_delete = session.query(File).filter_by(id=file_id).first() + if file_to_delete: + session.delete(file_to_delete) + self.logger.info(f"Deleted file entry {file_id} from relational DB.") + else: + self.logger.warning(f"File entry {file_id} not found in relational DB.") + + session.commit() + self.logger.info(f"Data deletion completed for file_id: {file_id}.") + except Exception as e: + session.rollback() + self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True) + finally: + session.close() + + async def delete_kb_by_id(self, kb_id: int): + """ + Deletes a knowledge base and all associated files and chunks. + """ + self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}") + session = SessionLocal() + try: + # 1. Get the knowledge base + kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if not kb: + self.logger.warning(f"Knowledge Base with ID {kb_id} not found.") + return + + # 2. Delete all files associated with this knowledge base + files_to_delete = session.query(File).filter_by(kb_id=kb.id).all() + for file in files_to_delete: + await self.delete_data_by_file_id(file.id) + + # 3. Delete the knowledge base itself + session.delete(kb) + session.commit() + self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}") + except Exception as e: + session.rollback() + self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True) + finally: + session.close() diff --git a/pkg/rag/knowledge/services/__init__.py b/pkg/rag/knowledge/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py new file mode 100644 index 00000000..0298226a --- /dev/null +++ b/pkg/rag/knowledge/services/base_service.py @@ -0,0 +1,26 @@ +# 封装异步操作 +import asyncio +import logging +from services.database import SessionLocal # 导入 SessionLocal 工厂函数 + +class BaseService: + def __init__(self): + self.logger = logging.getLogger(self.__class__.__name__) + self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数 + + async def _run_sync(self, func, *args, **kwargs): + """ + 在单独的线程中运行同步函数。 + 如果第一个参数是 session,则在 to_thread 中获取新的 session。 + """ + # 如果函数需要数据库会话作为第一个参数,我们在这里获取它 + if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头 + session = await asyncio.to_thread(self.db_session_factory) + try: + result = await asyncio.to_thread(func, session, *args, **kwargs) + return result + finally: + session.close() + else: + # 否则,直接运行同步函数 + return await asyncio.to_thread(func, *args, **kwargs) \ No newline at end of file diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py new file mode 100644 index 00000000..6a469168 --- /dev/null +++ b/pkg/rag/knowledge/services/chroma_manager.py @@ -0,0 +1,65 @@ +# services/chroma_manager.py +import numpy as np +import logging +from chromadb import PersistentClient +import os + +logger = logging.getLogger(__name__) + +class ChromaIndexManager: + def __init__(self, collection_name: str = "default_collection"): + self.logger = logging.getLogger(self.__class__.__name__) + chroma_data_path = "./chroma_data" + os.makedirs(chroma_data_path, exist_ok=True) + self.client = PersistentClient(path=chroma_data_path) + self._collection_name = collection_name + self._collection = None + + self.logger.info(f"ChromaIndexManager initialized. Collection name: {self._collection_name}") + + @property + def collection(self): + if self._collection is None: + self._collection = self.client.get_or_create_collection(name=self._collection_name) + self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.") + return self._collection + + def add_embeddings_sync(self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]): + if embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents): + raise ValueError("Embedding, file_id, chunk_id, and document count mismatch.") + + chroma_ids = [f"{file_id}_{chunk_id}" for file_id, chunk_id in zip(file_ids, chunk_ids)] + metadatas = [{"file_id": fid, "chunk_id": cid} for fid, cid in zip(file_ids, chunk_ids)] + + self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") + self.collection.add( + embeddings=embeddings.tolist(), + ids=chroma_ids, + metadatas=metadatas, + documents=documents + ) + self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") + + def search_sync(self, query_embedding: np.ndarray, k: int = 5): + """ + Searches the Chroma collection for the top-k nearest neighbors. + Args: + query_embedding: A numpy array of the query embedding. + k: The number of results to return. + Returns: + A dictionary containing query results from Chroma. + """ + self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.") + results = self.collection.query( + query_embeddings=query_embedding.tolist(), + n_results=k, + # REMOVE 'ids' from the include list. It's returned by default. + include=["metadatas", "distances", "documents"] + ) + self.logger.debug(f"Chroma search returned {len(results.get('ids', [[]])[0])} results.") + return results + + def delete_by_file_id_sync(self, file_id: int): + self.logger.info(f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'.") + self.collection.delete(where={"file_id": file_id}) + self.logger.info(f"Deleted embeddings for file_id: {file_id} from Chroma.") \ No newline at end of file diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py new file mode 100644 index 00000000..f115dac4 --- /dev/null +++ b/pkg/rag/knowledge/services/chunker.py @@ -0,0 +1,63 @@ +# services/chunker.py +import logging +from typing import List +from services.base_service import BaseService # Assuming BaseService provides _run_sync + +logger = logging.getLogger(__name__) + +class Chunker(BaseService): + """ + A class for splitting long texts into smaller, overlapping chunks. + """ + def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50): + super().__init__() # Initialize BaseService + self.logger = logging.getLogger(self.__class__.__name__) + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + if self.chunk_overlap >= self.chunk_size: + self.logger.warning("Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.") + + def _split_text_sync(self, text: str) -> List[str]: + """ + Synchronously splits a long text into chunks with specified overlap. + This is a CPU-bound operation, intended to be run in a separate thread. + """ + if not text: + return [] + + # Simple whitespace-based splitting for demonstration + # For more advanced chunking, consider libraries like LangChain's text splitters + words = text.split() + chunks = [] + current_chunk = [] + + for word in words: + current_chunk.append(word) + if len(current_chunk) > self.chunk_size: + chunks.append(" ".join(current_chunk[:self.chunk_size])) + current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:] + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + # A more robust chunking strategy (e.g., using recursive character text splitter) + # from langchain.text_splitter import RecursiveCharacterTextSplitter + # text_splitter = RecursiveCharacterTextSplitter( + # chunk_size=self.chunk_size, + # chunk_overlap=self.chunk_overlap, + # length_function=len, + # is_separator_regex=False, + # ) + # return text_splitter.split_text(text) + + return [chunk for chunk in chunks if chunk.strip()] # Filter out empty chunks + + async def chunk(self, text: str) -> List[str]: + """ + Asynchronously chunks a given text into smaller pieces. + """ + self.logger.info(f"Chunking text (length: {len(text)})...") + # Run the synchronous splitting logic in a separate thread + chunks = await self._run_sync(self._split_text_sync, text) + self.logger.info(f"Text chunked into {len(chunks)} pieces.") + return chunks \ No newline at end of file diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py new file mode 100644 index 00000000..4ec21af3 --- /dev/null +++ b/pkg/rag/knowledge/services/database.py @@ -0,0 +1,57 @@ +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary +from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from datetime import datetime +import numpy as np # 用于处理从LargeBinary转换回来的embedding + +Base = declarative_base() + +class KnowledgeBase(Base): + __tablename__ = 'kb' + id = Column(Integer, primary_key=True, index=True) + name = Column(String, index=True) + description = Column(Text) + created_at = Column(DateTime, default=datetime.utcnow) + + files = relationship("File", back_populates="knowledge_base") + +class File(Base): + __tablename__ = 'file' + id = Column(Integer, primary_key=True, index=True) + kb_id = Column(Integer, ForeignKey('kb.id')) + file_name = Column(String) + path = Column(String) + created_at = Column(DateTime, default=datetime.utcnow) + file_type = Column(String) + status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误 + knowledge_base = relationship("KnowledgeBase", back_populates="files") + chunks = relationship("Chunk", back_populates="file") + +class Chunk(Base): + __tablename__ = 'chunks' + id = Column(Integer, primary_key=True, index=True) + file_id = Column(Integer, ForeignKey('file.id')) + text = Column(Text) + + file = relationship("File", back_populates="chunks") + vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship("Chunk", back_populates="vector") + +# 数据库连接 +DATABASE_URL = "sqlite:///./knowledge_base.db" # 生产环境请更换为 PostgreSQL/MySQL +engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建所有表 (可以在应用启动时执行一次) +def create_db_and_tables(): + Base.metadata.create_all(bind=engine) + print("Database tables created/checked.") + +# 定义嵌入维度(请根据你实际使用的模型调整) +EMBEDDING_DIM = 1024 \ No newline at end of file diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py new file mode 100644 index 00000000..2b581e96 --- /dev/null +++ b/pkg/rag/knowledge/services/embedder.py @@ -0,0 +1,93 @@ +# services/embedder.py +import asyncio +import logging +import numpy as np +from typing import List +from sqlalchemy.orm import Session +from services.base_service import BaseService +from services.database import Chunk, SessionLocal +from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from services.chroma_manager import ChromaIndexManager # Import the manager + +logger = logging.getLogger(__name__) + +class Embedder(BaseService): + def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): + super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.model_type = model_type + self.model_name_key = model_name_key + self.chroma_manager = chroma_manager # Dependency Injection + + self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() + + def _load_embedding_model(self) -> BaseEmbeddingModel: + self.logger.info(f"Loading embedding model: type={self.model_type}, name_key={self.model_name_key}...") + try: + model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) + self.logger.info(f"Embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") + return model + except Exception as e: + self.logger.error(f"Failed to load embedding model '{self.model_name_key}': {e}") + raise + + def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]): + """ + Saves chunks to the relational database and returns the created Chunk objects. + This function assumes it's called within a context where the session + will be committed/rolled back and closed by the caller. + """ + self.logger.debug(f"Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).") + chunk_objects = [] + for text in chunks_texts: + chunk = Chunk(file_id=file_id, text=text) + session.add(chunk) + chunk_objects.append(chunk) + session.flush() # This populates the .id attribute for each new chunk object + self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.") + return chunk_objects + + async def embed_and_store(self, file_id: int, chunks: List[str]): + if not self.embedding_model: + raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.") + + self.logger.info(f"Embedding {len(chunks)} chunks for file_id: {file_id} using {self.model_name_key}...") + + session = SessionLocal() # Start a session that will live for the whole operation + chunk_objects = [] + try: + # 1. Save chunks to the relational database first to get their IDs + # We call _db_save_chunks_sync directly without _run_sync's session management + # because we manage the session here across multiple async calls. + chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks) + session.commit() # Commit chunks to make their IDs permanent and accessible + + if not chunk_objects: + self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.") + return [] + + # 2. Generate embeddings + embeddings: List[List[float]] = await self.embedding_model.embed_documents(chunks) + embeddings_np = np.array(embeddings, dtype=np.float32) + + if embeddings_np.shape[1] != self.embedding_model.embedding_dimension: + self.logger.error(f"Mismatch in embedding dimension: Model returned {embeddings_np.shape[1]}, expected {self.embedding_model.embedding_dimension}. Aborting storage.") + raise ValueError("Embedding dimension mismatch during embedding process.") + + self.logger.info("Saving embeddings to Chroma...") + chunk_ids = [c.id for c in chunk_objects] # Now safe to access .id because session is still open and committed + file_ids_for_chroma = [file_id] * len(chunk_ids) + + await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call + self.chroma_manager.add_embeddings_sync, + file_ids_for_chroma, chunk_ids, embeddings_np, chunks # Pass original chunks texts for documents + ) + self.logger.info(f"Successfully saved {len(chunk_objects)} embeddings to Chroma.") + return chunk_objects + + except Exception as e: + session.rollback() # Rollback on any error + self.logger.error(f"Failed to process and store data for file_id {file_id}: {e}", exc_info=True) + raise # Re-raise the exception to propagate it + finally: + session.close() # Ensure the session is always closed \ No newline at end of file diff --git a/pkg/rag/knowledge/services/embedding_models.py b/pkg/rag/knowledge/services/embedding_models.py new file mode 100644 index 00000000..a6ce73ae --- /dev/null +++ b/pkg/rag/knowledge/services/embedding_models.py @@ -0,0 +1,223 @@ +# services/embedding_models.py + +import os +from typing import Dict, Any, List, Type, Optional +import logging +import aiohttp # Import aiohttp for asynchronous requests +import asyncio +from sentence_transformers import SentenceTransformer + +logger = logging.getLogger(__name__) + +# Base class for all embedding models +class BaseEmbeddingModel: + def __init__(self, model_name: str): + self.model_name = model_name + self._embedding_dimension = None + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronously embeds a list of texts.""" + raise NotImplementedError + + async def embed_query(self, text: str) -> List[float]: + """Asynchronously embeds a single query text.""" + raise NotImplementedError + + @property + def embedding_dimension(self) -> int: + """Returns the embedding dimension of the model.""" + if self._embedding_dimension is None: + raise NotImplementedError("Embedding dimension not set for this model.") + return self._embedding_dimension + +class EmbeddingModelFactory: + @staticmethod + def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel: + """ + Factory method to create an embedding model instance. + Currently only supports 'third_party_api' types. + """ + if model_name_key not in EMBEDDING_MODEL_CONFIGS: + raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.") + + config = EMBEDDING_MODEL_CONFIGS[model_name_key] + + if config['type'] == "third_party_api": + required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension'] + if not all(key in config for key in required_keys): + raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}") + + # Retrieve model_name from config if it differs from model_name_key + # Some APIs expect a specific 'model' value in the payload that might be different from the key + api_model_name = config.get('model_name', model_name_key) + + return ThirdPartyAPIEmbeddingModel( + model_name=api_model_name, # Use the model_name from config or the key + api_endpoint=config['api_endpoint'], + headers=config['headers'], + payload_template=config['payload_template'], + embedding_dimension=config['embedding_dimension'] + ) + +class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): + def __init__(self, model_name: str): + super().__init__(model_name) + try: + # SentenceTransformer is inherently synchronous, but we'll wrap its calls + # in async methods. The actual computation will still block the event loop + # if not run in a separate thread/process, but this keeps the API consistent. + self.model = SentenceTransformer(model_name) + self._embedding_dimension = self.model.get_sentence_embedding_dimension() + logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}") + except Exception as e: + logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}") + raise + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + # For CPU-bound tasks like local model inference, consider running in a thread pool + # to prevent blocking the event loop for long operations. + # For simplicity here, we'll call it directly. + return self.model.encode(texts).tolist() + + async def embed_query(self, text: str) -> List[float]: + return self.model.encode(text).tolist() + + +class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): + def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int): + super().__init__(model_name) + self.api_endpoint = api_endpoint + self.headers = headers + self.payload_template = payload_template + self._embedding_dimension = embedding_dimension + self.session = None # aiohttp client session will be initialized on first use or in a context manager + logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}") + + async def _get_session(self): + """Lazily create or return the aiohttp client session.""" + if self.session is None or self.session.closed: + self.session = aiohttp.ClientSession() + return self.session + + async def close_session(self): + """Explicitly close the aiohttp client session.""" + if self.session and not self.session.closed: + await self.session.close() + self.session = None + logger.info(f"Closed aiohttp session for model {self.model_name}") + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronously embeds a list of texts using the third-party API.""" + session = await self._get_session() + embeddings = [] + tasks = [] + for text in texts: + payload = self.payload_template.copy() + if 'input' in payload: + payload['input'] = text + elif 'texts' in payload: + payload['texts'] = [text] + else: + raise ValueError("Payload template does not contain expected text input key.") + + tasks.append(self._make_api_request(session, payload)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, res in enumerate(results): + if isinstance(res, Exception): + logger.error(f"Error embedding text '{texts[i][:50]}...': {res}") + # Depending on your error handling strategy, you might: + # - Append None or an empty list + # - Re-raise the exception to stop processing + # - Log and skip, then continue + embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure + else: + embeddings.append(res) + + return embeddings + + async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]: + """Helper to make an asynchronous API request and extract embedding.""" + try: + async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response: + response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx) + api_response = await response.json() + + # Adjust this based on your API's actual response structure + if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]: + embedding = api_response["data"][0]["embedding"] + if len(embedding) != self.embedding_dimension: + logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + return embedding + elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]: + embedding = api_response["embeddings"][0] + if len(embedding) != self.embedding_dimension: + logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + return embedding + else: + raise ValueError(f"Unexpected API response structure: {api_response}") + + except aiohttp.ClientError as e: + raise ConnectionError(f"API request failed: {e}") from e + except ValueError as e: + raise ValueError(f"Error processing API response: {e}") from e + + + async def embed_query(self, text: str) -> List[float]: + """Asynchronously embeds a single query text.""" + results = await self.embed_documents([text]) + if results: + return results[0] + return [] # Or raise an error if embedding a query must always succeed + +# --- Embedding Model Configuration --- +EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = { + "MiniLM": { # Example for a local Sentence Transformer model + "type": "sentence_transformer", + "model_name": "sentence-transformers/all-MiniLM-L6-v2" + }, + "bge-m3": { # Example for a third-party API model + "type": "third_party_api", + "model_name": "bge-m3", + "api_endpoint": "https://api.qhaigc.net/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('rag_api_key')}" + }, + "payload_template": { + "model": "bge-m3", + "input": "" + }, + "embedding_dimension": 1024 + }, + "OpenAI-Ada-002": { + "type": "third_party_api", + "model_name": "text-embedding-ada-002", + "api_endpoint": "https://api.openai.com/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set + }, + "payload_template": { + "model": "text-embedding-ada-002", + "input": "" # Text will be injected here + }, + "embedding_dimension": 1536 + }, + "OpenAI-Embedding-3-Small": { + "type": "third_party_api", + "model_name": "text-embedding-3-small", + "api_endpoint": "https://api.openai.com/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" + }, + "payload_template": { + "model": "text-embedding-3-small", + "input": "", + # "dimensions": 512 # Optional: uncomment if you want a specific output dimension + }, + "embedding_dimension": 1536 # Default max dimension for text-embedding-3-small + }, +} \ No newline at end of file diff --git a/pkg/rag/knowledge/services/parser.py b/pkg/rag/knowledge/services/parser.py new file mode 100644 index 00000000..5fa7d589 --- /dev/null +++ b/pkg/rag/knowledge/services/parser.py @@ -0,0 +1,288 @@ + +import PyPDF2 +from docx import Document +import pandas as pd +import csv +import chardet +from typing import Union, List, Callable, Any +import logging +import markdown +from bs4 import BeautifulSoup +import ebooklib +from ebooklib import epub +import re +import asyncio # Import asyncio for async operations +import os + +# Configure logging +logger = logging.getLogger(__name__) + +class FileParser: + """ + A robust file parser class to extract text content from various document formats. + It supports TXT, PDF, DOCX, XLSX, CSV, Markdown, HTML, and EPUB files. + All core file reading operations are designed to be run synchronously in a thread pool + to avoid blocking the asyncio event loop. + """ + def __init__(self): + + self.logger = logging.getLogger(self.__class__.__name__) + + async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any: + """ + Runs a synchronous function in a separate thread to prevent blocking the event loop. + This is a general utility method for wrapping blocking I/O operations. + """ + try: + return await asyncio.to_thread(sync_func, *args, **kwargs) + except Exception as e: + self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}") + raise + + async def parse(self, file_path: str) -> Union[str, None]: + """ + Parses the file based on its extension and returns the extracted text content. + This is the main asynchronous entry point for parsing. + + Args: + file_path (str): The path to the file to be parsed. + + Returns: + Union[str, None]: The extracted text content as a single string, or None if parsing fails. + """ + if not file_path or not os.path.exists(file_path): + self.logger.error(f"Invalid file path provided: {file_path}") + return None + + file_extension = file_path.split('.')[-1].lower() + parser_method = getattr(self, f'_parse_{file_extension}', None) + + if parser_method is None: + self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}") + return None + + try: + # Pass file_path to the specific parser methods + return await parser_method(file_path) + except Exception as e: + self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}") + return None + + # --- Helper for reading files with encoding detection --- + async def _read_file_content(self, file_path: str, mode: str = 'r') -> Union[str, bytes]: + """ + Reads a file with automatic encoding detection, ensuring the synchronous + file read operation runs in a separate thread. + """ + def _read_sync(): + with open(file_path, 'rb') as file: + raw_data = file.read() + detected = chardet.detect(raw_data) + encoding = detected['encoding'] or 'utf-8' + + if mode == 'r': + return raw_data.decode(encoding, errors='ignore') + return raw_data # For binary mode + + return await self._run_sync(_read_sync) + + # --- Specific Parser Methods --- + + async def _parse_txt(self, file_path: str) -> str: + """Parses a TXT file and returns its content.""" + self.logger.info(f"Parsing TXT file: {file_path}") + return await self._read_file_content(file_path, mode='r') + + async def _parse_pdf(self, file_path: str) -> str: + """Parses a PDF file and returns its text content.""" + self.logger.info(f"Parsing PDF file: {file_path}") + def _parse_pdf_sync(): + text_content = [] + with open(file_path, 'rb') as file: + pdf_reader = PyPDF2.PdfReader(file) + for page in pdf_reader.pages: + text = page.extract_text() + if text: + text_content.append(text) + return '\n'.join(text_content) + return await self._run_sync(_parse_pdf_sync) + + async def _parse_docx(self, file_path: str) -> str: + """Parses a DOCX file and returns its text content.""" + self.logger.info(f"Parsing DOCX file: {file_path}") + def _parse_docx_sync(): + doc = Document(file_path) + text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()] + return '\n'.join(text_content) + return await self._run_sync(_parse_docx_sync) + + async def _parse_doc(self, file_path: str) -> str: + """Handles .doc files, explicitly stating lack of direct support.""" + self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.") + raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.") + + async def _parse_xlsx(self, file_path: str) -> str: + """Parses an XLSX file, returning text from all sheets.""" + self.logger.info(f"Parsing XLSX file: {file_path}") + def _parse_xlsx_sync(): + excel_file = pd.ExcelFile(file_path) + all_sheet_content = [] + for sheet_name in excel_file.sheet_names: + df = pd.read_excel(file_path, sheet_name=sheet_name) + sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n" + all_sheet_content.append(sheet_text) + return '\n'.join(all_sheet_content) + return await self._run_sync(_parse_xlsx_sync) + + async def _parse_csv(self, file_path: str) -> str: + """Parses a CSV file and returns its content as a string.""" + self.logger.info(f"Parsing CSV file: {file_path}") + def _parse_csv_sync(): + # pd.read_csv can often detect encoding, but explicit detection is safer + raw_data = self._read_file_content(file_path, mode='rb') # Note: this will need to be await outside this sync function + # For simplicity, we'll let pandas handle encoding internally after a raw read. + # A more robust solution might pass encoding directly to pd.read_csv after detection. + detected = chardet.detect(open(file_path, 'rb').read()) + encoding = detected['encoding'] or 'utf-8' + df = pd.read_csv(file_path, encoding=encoding) + return df.to_string(index=False) + return await self._run_sync(_parse_csv_sync) + + async def _parse_markdown(self, file_path: str) -> str: + """Parses a Markdown file, converting it to structured plain text.""" + self.logger.info(f"Parsing Markdown file: {file_path}") + def _parse_markdown_sync(): + md_content = self._read_file_content(file_path, mode='r') # This is a synchronous call within a sync function + html_content = markdown.markdown( + md_content, + extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code'] + ) + soup = BeautifulSoup(html_content, 'html.parser') + text_parts = [] + for element in soup.children: + if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']: + level = int(element.name[1]) + text_parts.append('#' * level + ' ' + element.get_text().strip()) + elif element.name == 'p': + text = element.get_text().strip() + if text: + text_parts.append(text) + elif element.name in ['ul', 'ol']: + for li in element.find_all('li'): + text_parts.append(f"* {li.get_text().strip()}") + elif element.name == 'pre': + code_block = element.get_text().strip() + if code_block: + text_parts.append(f"```\n{code_block}\n```") + elif element.name == 'table': + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + if table_str: + text_parts.append(table_str) + elif element.name: + text = element.get_text(separator=' ', strip=True) + if text: + text_parts.append(text) + cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) + return cleaned_text.strip() + return await self._run_sync(_parse_markdown_sync) + + async def _parse_html(self, file_path: str) -> str: + """Parses an HTML file, extracting structured plain text.""" + self.logger.info(f"Parsing HTML file: {file_path}") + def _parse_html_sync(): + html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function + soup = BeautifulSoup(html_content, 'html.parser') + for script_or_style in soup(["script", "style"]): + script_or_style.decompose() + text_parts = [] + for element in soup.body.children if soup.body else soup.children: + if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']: + level = int(element.name[1]) + text_parts.append('#' * level + ' ' + element.get_text().strip()) + elif element.name == 'p': + text = element.get_text().strip() + if text: + text_parts.append(text) + elif element.name in ['ul', 'ol']: + for li in element.find_all('li'): + text = li.get_text().strip() + if text: + text_parts.append(f"* {text}") + elif element.name == 'table': + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + if table_str: + text_parts.append(table_str) + elif element.name: + text = element.get_text(separator=' ', strip=True) + if text: + text_parts.append(text) + cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) + return cleaned_text.strip() + return await self._run_sync(_parse_html_sync) + + async def _parse_epub(self, file_path: str) -> str: + """Parses an EPUB file, extracting metadata and content.""" + self.logger.info(f"Parsing EPUB file: {file_path}") + def _parse_epub_sync(): + book = epub.read_epub(file_path) + text_content = [] + title_meta = book.get_metadata('DC', 'title') + if title_meta: + text_content.append(f"Title: {title_meta[0][0]}") + creator_meta = book.get_metadata('DC', 'creator') + if creator_meta: + text_content.append(f"Author: {creator_meta[0][0]}") + date_meta = book.get_metadata('DC', 'date') + if date_meta: + text_content.append(f"Publish Date: {date_meta[0][0]}") + toc = book.get_toc() + if toc: + text_content.append("\n--- Table of Contents ---") + self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper + text_content.append("--- End of Table of Contents ---\n") + for item in book.get_items(): + if item.get_type() == ebooklib.ITEM_DOCUMENT: + html_content = item.get_content().decode('utf-8', errors='ignore') + soup = BeautifulSoup(html_content, 'html.parser') + for junk in soup(["script", "style", "nav", "header", "footer"]): + junk.decompose() + text = soup.get_text(separator='\n', strip=True) + text = re.sub(r'\n\s*\n', '\n\n', text) + if text: + text_content.append(text) + return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip() + return await self._run_sync(_parse_epub_sync) + + def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int): + """Recursively adds TOC items to text_content (synchronous helper).""" + indent = ' ' * level + for item in toc_list: + if isinstance(item, tuple): + chapter, subchapters = item + text_content.append(f"{indent}- {chapter.title}") + self._add_toc_items_sync(subchapters, text_content, level + 1) + else: + text_content.append(f"{indent}- {item.title}") + + def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str: + """Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous).""" + headers = [th.get_text().strip() for th in table_element.find_all('th')] + rows = [] + for tr in table_element.find_all('tr'): + cells = [td.get_text().strip() for td in tr.find_all('td')] + if cells: + rows.append(cells) + + if not headers and not rows: + return "" + + table_lines = [] + if headers: + table_lines.append(' | '.join(headers)) + table_lines.append(' | '.join(['---'] * len(headers))) + + for row_cells in rows: + padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells + table_lines.append(' | '.join(padded_cells)) + + return '\n'.join(table_lines) \ No newline at end of file diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py new file mode 100644 index 00000000..6da1c5d8 --- /dev/null +++ b/pkg/rag/knowledge/services/retriever.py @@ -0,0 +1,106 @@ +# services/retriever.py +import asyncio +import logging +import numpy as np # Make sure numpy is imported +from typing import List, Dict, Any +from sqlalchemy.orm import Session +from services.base_service import BaseService +from services.database import Chunk, SessionLocal +from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from services.chroma_manager import ChromaIndexManager + +logger = logging.getLogger(__name__) + +class Retriever(BaseService): + def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): + super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.model_type = model_type + self.model_name_key = model_name_key + self.chroma_manager = chroma_manager + + self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() + + def _load_embedding_model(self) -> BaseEmbeddingModel: + self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...") + try: + model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) + self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") + return model + except Exception as e: + self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}") + raise + + async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: + if not self.embedding_model: + raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.") + + self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}") + + query_embedding: List[float] = await self.embedding_model.embed_query(query) + query_embedding_np = np.array([query_embedding], dtype=np.float32) + + chroma_results = await self._run_sync( + self.chroma_manager.search_sync, + query_embedding_np, k + ) + + # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' + matched_chroma_ids = chroma_results.get("ids", [[]])[0] + distances = chroma_results.get("distances", [[]])[0] + chroma_metadatas = chroma_results.get("metadatas", [[]])[0] + chroma_documents = chroma_results.get("documents", [[]])[0] + + if not matched_chroma_ids: + self.logger.info("No relevant chunks found in Chroma.") + return [] + + db_chunk_ids = [] + for metadata in chroma_metadatas: + if "chunk_id" in metadata: + db_chunk_ids.append(metadata["chunk_id"]) + else: + self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.") + + if not db_chunk_ids: + self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.") + return [] + + self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...") + chunks_from_db = await self._run_sync( + lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync + db_chunk_ids + ) + + chunk_map = {chunk.id: chunk for chunk in chunks_from_db} + results_list: List[Dict[str, Any]] = [] + + for i, chroma_id in enumerate(matched_chroma_ids): + try: + # Ensure original_chunk_id is int for DB lookup + original_chunk_id = int(chroma_id.split('_')[-1]) + except (ValueError, IndexError): + self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.") + continue + + chunk_text_from_chroma = chroma_documents[i] + distance = float(distances[i]) + file_id_from_chroma = chroma_metadatas[i].get("file_id") + + chunk_from_db = chunk_map.get(original_chunk_id) + + results_list.append({ + "chunk_id": original_chunk_id, + "text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, + "distance": distance, + "file_id": file_id_from_chroma + }) + + self.logger.info(f"Retrieved {len(results_list)} chunks.") + return results_list + + def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]: + self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).") + chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all() + session.close() + return chunks \ No newline at end of file diff --git a/pkg/rag/knowledge/utils/crawler.py b/pkg/rag/knowledge/utils/crawler.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index 5e85bfb0..27a03a92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,17 @@ dependencies = [ "ruff>=0.11.9", "pre-commit>=4.2.0", "uv>=0.7.11", + "PyPDF2>=3.0.1", + "python-docx>=1.1.0", + "pandas>=2.2.2", + "chardet>=5.2.0", + "markdown>=3.6", + "beautifulsoup4>=4.12.3", + "ebooklib>=0.18", + "html2text>=2024.2.26", + "langchain>=0.2.0", + "chromadb>=0.4.24", + "sentence-transformers>=2.6.1", ] keywords = [ "bot", From c4671fbf1c046c83f8478bebb95af8865f9f6cf8 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Thu, 26 Jun 2025 14:09:26 +0800 Subject: [PATCH 5/7] delete ap --- pkg/rag/knowledge/RAG_Manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index e172c132..d85699d3 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -21,7 +21,6 @@ class RAG_Manager: self.chunker = None self.embedder = None self.retriever = None - self.ap = app.Application async def initialize_system(self): await asyncio.to_thread(create_db_and_tables) From 34fe8b324d0b5e1412641ef3b6029372c2327a90 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Thu, 3 Jul 2025 23:28:47 +0800 Subject: [PATCH 6/7] feat: add functions --- .../http/controller/groups/knowledge_base.py | 27 ++++++++-------- .../controller/groups/pipelines/pipelines.py | 2 +- pkg/core/app.py | 10 +++--- pkg/core/entities.py | 2 +- pkg/core/stages/build_app.py | 7 ++++ pkg/entity/persistence/vector.py | 14 ++++++++ pkg/rag/knowledge/RAG_Manager.py | 32 +++++++++++-------- pkg/rag/knowledge/services/base_service.py | 8 ++--- pkg/rag/knowledge/services/chroma_manager.py | 4 +-- pkg/rag/knowledge/services/chunker.py | 2 +- pkg/rag/knowledge/services/embedder.py | 8 ++--- pkg/rag/knowledge/services/retriever.py | 8 ++--- 12 files changed, 75 insertions(+), 49 deletions(-) create mode 100644 pkg/entity/persistence/vector.py diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py index c819397a..f9aa09e0 100644 --- a/pkg/api/http/controller/groups/knowledge_base.py +++ b/pkg/api/http/controller/groups/knowledge_base.py @@ -1,5 +1,4 @@ import quart -from __future__ import annotations from .. import group @group.group_class('knowledge_base', '/api/v1/knowledge/bases') @@ -16,13 +15,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): async def initialize(self) -> None: - rag = self.ap.knowledge_base_service.RAG_Manager() + @self.route('', methods=['POST', 'GET']) async def _() -> str: if quart.request.method == 'GET': - knowledge_bases = await rag.get_all_knowledge_bases() + knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases() bases_list = [ { "uuid": kb.id, @@ -35,17 +34,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): msg='ok') json_data = await quart.request.json - knowledge_base_uuid = await rag.create_knowledge_base( + knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( json_data.get('name'), json_data.get('description') ) - return self.success() + return self.success(code=0, + data={}, + msg='ok') - @self.route('/', methods=['GET']) + @self.route('/', methods=['GET','DELETE']) async def _(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid) + knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid) if knowledge_base is None: return self.http_status(404, -1, 'knowledge base not found') @@ -59,11 +60,14 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): }, msg='ok' ) + elif quart.request.method == 'DELETE': + await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid) + return self.success(code=0, msg='ok') @self.route('//files', methods=['GET']) async def _(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - files = await rag.get_files_by_knowledge_base(knowledge_base_uuid) + files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) return self.success(code=0,data=[{ "id": file.id, "file_name": file.file_name, @@ -73,11 +77,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): # delete specific file in knowledge base @self.route('//files/', methods=['DELETE']) async def _(knowledge_base_uuid: str, file_id: str) -> str: - await rag.delete_data_by_file_id(file_id) + await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) return self.success(code=0, msg='ok') - # delete specific kb - @self.route('/', methods=['DELETE']) - async def _(knowledge_base_uuid: str) -> str: - await rag.delete_kb_by_id(knowledge_base_uuid) - return self.success(code=0, msg='ok') diff --git a/pkg/api/http/controller/groups/pipelines/pipelines.py b/pkg/api/http/controller/groups/pipelines/pipelines.py index 96ca239a..1a8036cc 100644 --- a/pkg/api/http/controller/groups/pipelines/pipelines.py +++ b/pkg/api/http/controller/groups/pipelines/pipelines.py @@ -2,7 +2,7 @@ from __future__ import annotations import quart -from ... import group +from .. import group @group.group_class('pipelines', '/api/v1/pipelines') diff --git a/pkg/core/app.py b/pkg/core/app.py index d8824466..2e3c9500 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -27,10 +27,7 @@ from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities -from ...pkg.rag.knowledge import RAG_Manager - - - +from pkg.rag.knowledge.RAG_Manager import RAG_Manager class Application: @@ -51,6 +48,7 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None + # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None @@ -103,7 +101,6 @@ class Application: storage_mgr: storagemgr.StorageMgr = None - knowledge_base_service: RAG_Manager = None # ========= HTTP Services ========= @@ -117,6 +114,8 @@ class Application: bot_service: bot_service.BotService = None + knowledge_base_service: RAG_Manager = None + def __init__(self): pass @@ -152,6 +151,7 @@ class Application: name='http-api-controller', scopes=[core_entities.LifecycleControlScope.APPLICATION], ) + self.task_mgr.create_task( never_ending(), name='never-ending-task', diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 4caf18ed..8dc51e5b 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -19,7 +19,7 @@ class LifecycleControlScope(enum.Enum): APPLICATION = 'application' PLATFORM = 'platform' PLUGIN = 'plugin' - PROVIDER = 'provider' + PROVIDER = 'provider' class LauncherTypes(enum.Enum): diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 482a468b..3ba468c8 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -9,6 +9,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.tools import toolmgr as llm_tool_mgr +from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr from ...platform import botmgr as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -101,6 +102,12 @@ class BuildAppStage(stage.BootingStage): embedding_models_service_inst = model_service.EmbeddingModelsService(ap) ap.embedding_models_service = embedding_models_service_inst + knowledge_base_service_inst = knowledge_base_mgr(ap) + print("knowledge_base_service_inst1", type(knowledge_base_service_inst)) + await knowledge_base_service_inst.initialize_rag_system() + ap.knowledge_base_service = knowledge_base_service_inst + print("knowledge_base_service_inst", type(ap.knowledge_base_service)) + pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst diff --git a/pkg/entity/persistence/vector.py b/pkg/entity/persistence/vector.py new file mode 100644 index 00000000..84d1dfb1 --- /dev/null +++ b/pkg/entity/persistence/vector.py @@ -0,0 +1,14 @@ +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary +from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from datetime import datetime +import numpy as np # 用于处理从LargeBinary转换回来的embedding + +Base = declarative_base() + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship("Chunk", back_populates="vector") \ No newline at end of file diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index d85699d3..292f23ce 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -1,18 +1,24 @@ # RAG_Manager class (main class, adjust imports as needed) +from __future__ import annotations # For type hinting in Python 3.7+ import logging import os import asyncio -from services.parser import FileParser -from services.chunker import Chunker -from services.embedder import Embedder -from services.retriever import Retriever -from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly -from services.embedding_models import EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager -from ...core import app +from pkg.rag.knowledge.services.parser import FileParser +from pkg.rag.knowledge.services.chunker import Chunker +from pkg.rag.knowledge.services.embedder import Embedder +from pkg.rag.knowledge.services.retriever import Retriever +from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly +from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager +from pkg.core import app # Adjust the import path as needed + class RAG_Manager: - def __init__(self, logger: logging.Logger = None): + + ap: app.Application + + def __init__(self, ap: app.Application,logger: logging.Logger = None): + self.ap = ap self.logger = logger or logging.getLogger(__name__) self.embedding_model_type = None self.embedding_model_name = None @@ -21,11 +27,11 @@ class RAG_Manager: self.chunker = None self.embedder = None self.retriever = None - - async def initialize_system(self): + + async def initialize_rag_system(self): await asyncio.to_thread(create_db_and_tables) - async def create_model(self, embedding_model_type: str, + async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str): self.embedding_model_type = embedding_model_type self.embedding_model_name = embedding_model_name @@ -57,7 +63,7 @@ class RAG_Manager: ) - async def create_knowledge_base(self, kb_name: str, kb_description: str): + async def create_knowledge_base(self, kb_name: str, kb_description: str ,): """ Creates a new knowledge base with the given name and description. If a knowledge base with the same name already exists, it returns that one. diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py index 0298226a..4ff1ce39 100644 --- a/pkg/rag/knowledge/services/base_service.py +++ b/pkg/rag/knowledge/services/base_service.py @@ -1,20 +1,20 @@ # 封装异步操作 import asyncio import logging -from services.database import SessionLocal # 导入 SessionLocal 工厂函数 +from pkg.rag.knowledge.services.database import SessionLocal class BaseService: def __init__(self): self.logger = logging.getLogger(self.__class__.__name__) - self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数 + self.db_session_factory = SessionLocal async def _run_sync(self, func, *args, **kwargs): """ 在单独的线程中运行同步函数。 如果第一个参数是 session,则在 to_thread 中获取新的 session。 """ - # 如果函数需要数据库会话作为第一个参数,我们在这里获取它 - if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头 + + if getattr(func, '__name__', '').startswith('_db_'): session = await asyncio.to_thread(self.db_session_factory) try: result = await asyncio.to_thread(func, session, *args, **kwargs) diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py index 6a469168..f8020cdb 100644 --- a/pkg/rag/knowledge/services/chroma_manager.py +++ b/pkg/rag/knowledge/services/chroma_manager.py @@ -1,4 +1,4 @@ -# services/chroma_manager.py + import numpy as np import logging from chromadb import PersistentClient @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) class ChromaIndexManager: def __init__(self, collection_name: str = "default_collection"): self.logger = logging.getLogger(self.__class__.__name__) - chroma_data_path = "./chroma_data" + chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma")) os.makedirs(chroma_data_path, exist_ok=True) self.client = PersistentClient(path=chroma_data_path) self._collection_name = collection_name diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py index f115dac4..17202a7a 100644 --- a/pkg/rag/knowledge/services/chunker.py +++ b/pkg/rag/knowledge/services/chunker.py @@ -1,7 +1,7 @@ # services/chunker.py import logging from typing import List -from services.base_service import BaseService # Assuming BaseService provides _run_sync +from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync logger = logging.getLogger(__name__) diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 2b581e96..7e20b19a 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -4,10 +4,10 @@ import logging import numpy as np from typing import List from sqlalchemy.orm import Session -from services.base_service import BaseService -from services.database import Chunk, SessionLocal -from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager # Import the manager +from pkg.rag.knowledge.services.base_service import BaseService +from pkg.rag.knowledge.services.database import Chunk, SessionLocal +from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager logger = logging.getLogger(__name__) diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index 6da1c5d8..4da81eb1 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -4,10 +4,10 @@ import logging import numpy as np # Make sure numpy is imported from typing import List, Dict, Any from sqlalchemy.orm import Session -from services.base_service import BaseService -from services.database import Chunk, SessionLocal -from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager +from pkg.rag.knowledge.services.base_service import BaseService +from pkg.rag.knowledge.services.database import Chunk, SessionLocal +from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager logger = logging.getLogger(__name__) From 552fee9bacf0ade89cabd77ee12f886c7d8693a1 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Sat, 5 Jul 2025 17:53:11 +0800 Subject: [PATCH 7/7] fix: modify rag database --- pkg/rag/knowledge/RAG_Manager.py | 6 +++--- pkg/rag/knowledge/services/database.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index 292f23ce..6ded737a 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -62,8 +62,8 @@ class RAG_Manager: chroma_manager=self.chroma_manager # Inject dependency ) - - async def create_knowledge_base(self, kb_name: str, kb_description: str ,): + + async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5): """ Creates a new knowledge base with the given name and description. If a knowledge base with the same name already exists, it returns that one. @@ -82,7 +82,7 @@ class RAG_Manager: def _add_kb_sync(): session = SessionLocal() try: - new_kb = KnowledgeBase(name=kb_name, description=kb_description) + new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k) session.add(new_kb) session.commit() session.refresh(new_kb) diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py index 4ec21af3..a8c35883 100644 --- a/pkg/rag/knowledge/services/database.py +++ b/pkg/rag/knowledge/services/database.py @@ -11,7 +11,8 @@ class KnowledgeBase(Base): name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) - + embedding_model = Column(String, default="") # 默认嵌入模型 + top_k = Column(Integer, default=5) # 默认返回的top_k数量 files = relationship("File", back_populates="knowledge_base") class File(Base):