From c81d5a1a49194e2e806801f53c32b405f12627e7 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Wed, 21 May 2025 12:42:39 +0800 Subject: [PATCH 01/60] feat: add embeddings model management (#1461) * feat: add embeddings model management backend support Co-Authored-By: Junyan Qin * feat: add embeddings model management frontend support Co-Authored-By: Junyan Qin * chore: revert HttpClient URL to production setting Co-Authored-By: Junyan Qin * refactor: integrate embeddings models into models page with tabs Co-Authored-By: Junyan Qin * perf: move files * perf: remove `s` * feat: allow requester to declare supported types in manifest * feat(embedding): delete dimension and encoding format * feat: add extra_args for embedding moels * perf: i18n ref * fix: linter err * fix: lint err * fix: linter err --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Junyan Qin --- .../http/controller/groups/provider/models.py | 55 +- .../controller/groups/provider/requesters.py | 3 +- pkg/api/http/service/model.py | 89 ++- pkg/core/app.py | 4 +- pkg/core/stages/build_app.py | 7 +- pkg/entity/persistence/model.py | 21 + pkg/provider/modelmgr/entities.py | 2 +- pkg/provider/modelmgr/modelmgr.py | 89 ++- pkg/provider/modelmgr/requester.py | 51 +- .../modelmgr/requesters/anthropicmsgs.py | 2 +- .../modelmgr/requesters/anthropicmsgs.yaml | 2 + .../modelmgr/requesters/bailianchatcmpl.yaml | 2 + pkg/provider/modelmgr/requesters/chatcmpl.py | 38 +- .../modelmgr/requesters/chatcmpl.yaml | 3 + .../modelmgr/requesters/deepseekchatcmpl.yaml | 2 + .../modelmgr/requesters/geminichatcmpl.yaml | 2 + .../modelmgr/requesters/giteeaichatcmpl.yaml | 2 + .../modelmgr/requesters/lmstudiochatcmpl.yaml | 2 + .../modelmgr/requesters/modelscopechatcmpl.py | 2 +- .../requesters/modelscopechatcmpl.yaml | 2 + .../modelmgr/requesters/moonshotchatcmpl.yaml | 2 + .../modelmgr/requesters/ollamachat.py | 2 +- .../modelmgr/requesters/ollamachat.yaml | 2 + .../requesters/openrouterchatcmpl.yaml | 2 + .../modelmgr/requesters/ppiochatcmpl.yaml | 2 + .../requesters/siliconflowchatcmpl.yaml | 2 + .../modelmgr/requesters/volcarkchatcmpl.yaml | 2 + .../modelmgr/requesters/xaichatcmpl.yaml | 2 + .../modelmgr/requesters/zhipuaichatcmpl.yaml | 2 + .../home-sidebar/sidbarConfigList.tsx | 1 + .../{llm-form => }/ChooseRequesterEntity.ts | 0 .../models/component/ICreateEmbeddingField.ts | 7 + .../models/{ => component}/ICreateLLMField.ts | 0 .../embedding-card/EmbeddingCard.module.css | 97 +++ .../embedding-card/EmbeddingCard.tsx | 53 ++ .../embedding-card/EmbeddingCardVO.ts | 23 + .../embedding-form/EmbeddingForm.tsx | 563 ++++++++++++++++++ .../models/component/llm-form/LLMForm.tsx | 8 +- web/src/app/home/models/page.tsx | 183 +++++- web/src/app/infra/entities/api/index.ts | 23 + web/src/app/infra/http/HttpClient.ts | 42 +- web/src/i18n/locales/en-US.ts | 18 +- web/src/i18n/locales/zh-Hans.ts | 18 +- 43 files changed, 1370 insertions(+), 64 deletions(-) rename web/src/app/home/models/component/{llm-form => }/ChooseRequesterEntity.ts (100%) create mode 100644 web/src/app/home/models/component/ICreateEmbeddingField.ts rename web/src/app/home/models/{ => component}/ICreateLLMField.ts (100%) create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts create mode 100644 web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx diff --git a/pkg/api/http/controller/groups/provider/models.py b/pkg/api/http/controller/groups/provider/models.py index bb77986c..0de0c922 100644 --- a/pkg/api/http/controller/groups/provider/models.py +++ b/pkg/api/http/controller/groups/provider/models.py @@ -9,18 +9,18 @@ class LLMModelsRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={'models': await self.ap.model_service.get_llm_models()}) + return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()}) elif quart.request.method == 'POST': json_data = await quart.request.json - model_uuid = await self.ap.model_service.create_llm_model(json_data) + model_uuid = await self.ap.llm_model_service.create_llm_model(json_data) return self.success(data={'uuid': model_uuid}) @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(model_uuid: str) -> str: if quart.request.method == 'GET': - model = await self.ap.model_service.get_llm_model(model_uuid) + model = await self.ap.llm_model_service.get_llm_model(model_uuid) if model is None: return self.http_status(404, -1, 'model not found') @@ -29,11 +29,11 @@ class LLMModelsRouterGroup(group.RouterGroup): elif quart.request.method == 'PUT': json_data = await quart.request.json - await self.ap.model_service.update_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.update_llm_model(model_uuid, json_data) return self.success() elif quart.request.method == 'DELETE': - await self.ap.model_service.delete_llm_model(model_uuid) + await self.ap.llm_model_service.delete_llm_model(model_uuid) return self.success() @@ -41,6 +41,49 @@ class LLMModelsRouterGroup(group.RouterGroup): async def _(model_uuid: str) -> str: json_data = await quart.request.json - await self.ap.model_service.test_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.test_llm_model(model_uuid, json_data) + + return self.success() + + +@group.group_class('models/embedding', '/api/v1/provider/models/embedding') +class EmbeddingModelsRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['GET', 'POST']) + async def _() -> str: + if quart.request.method == 'GET': + return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + + model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data) + + return self.success(data={'uuid': model_uuid}) + + @self.route('/', methods=['GET', 'PUT', 'DELETE']) + async def _(model_uuid: str) -> str: + if quart.request.method == 'GET': + model = await self.ap.embedding_models_service.get_embedding_model(model_uuid) + + if model is None: + return self.http_status(404, -1, 'model not found') + + return self.success(data={'model': model}) + elif quart.request.method == 'PUT': + json_data = await quart.request.json + + await self.ap.embedding_models_service.update_embedding_model(model_uuid, json_data) + + return self.success() + elif quart.request.method == 'DELETE': + await self.ap.embedding_models_service.delete_embedding_model(model_uuid) + + return self.success() + + @self.route('//test', methods=['POST']) + async def _(model_uuid: str) -> str: + json_data = await quart.request.json + + await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data) return self.success() diff --git a/pkg/api/http/controller/groups/provider/requesters.py b/pkg/api/http/controller/groups/provider/requesters.py index 0f999288..af9e1540 100644 --- a/pkg/api/http/controller/groups/provider/requesters.py +++ b/pkg/api/http/controller/groups/provider/requesters.py @@ -8,7 +8,8 @@ class RequestersRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> quart.Response: - return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()}) + model_type = quart.request.args.get('type', '') + return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info(model_type)}) @self.route('/', methods=['GET']) async def _(requester_name: str) -> quart.Response: diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 74fb4e02..afeae3eb 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -10,7 +10,7 @@ from ....provider.modelmgr import requester as model_requester from ....provider import entities as llm_entities -class ModelsService: +class LLMModelsService: ap: app.Application def __init__(self, ap: app.Application) -> None: @@ -103,3 +103,90 @@ class ModelsService: funcs=[], extra_args={}, ) + + +class EmbeddingModelsService: + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_embedding_models(self) -> list[dict]: + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + + models = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models] + + async def create_embedding_model(self, model_data: dict) -> str: + model_data['uuid'] = str(uuid.uuid4()) + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) + ) + + embedding_model = await self.get_embedding_model(model_data['uuid']) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + return model_data['uuid'] + + async def get_embedding_model(self, model_uuid: str) -> dict | None: + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + model = result.first() + + if model is None: + return None + + return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) + + async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None: + if 'uuid' in model_data: + del model_data['uuid'] + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.EmbeddingModel) + .where(persistence_model.EmbeddingModel.uuid == model_uuid) + .values(**model_data) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + embedding_model = await self.get_embedding_model(model_uuid) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + async def delete_embedding_model(self, model_uuid: str) -> None: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None: + runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None + + if model_uuid != '_': + for model in self.ap.model_mgr.embedding_models: + if model.model_entity.uuid == model_uuid: + runtime_embedding_model = model + break + + if runtime_embedding_model is None: + raise Exception('model not found') + + else: + runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data) + + await runtime_embedding_model.requester.invoke_embedding( + query=None, + model=runtime_embedding_model, + input_text='Hello, world!', + extra_args={}, + ) diff --git a/pkg/core/app.py b/pkg/core/app.py index 911acd3d..318cddcb 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -103,7 +103,9 @@ class Application: user_service: user_service.UserService = None - model_service: model_service.ModelsService = None + llm_model_service: model_service.LLMModelsService = None + + embedding_models_service: model_service.EmbeddingModelsService = None pipeline_service: pipeline_service.PipelineService = None diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 6ee35610..482a468b 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -95,8 +95,11 @@ class BuildAppStage(stage.BootingStage): user_service_inst = user_service.UserService(ap) ap.user_service = user_service_inst - model_service_inst = model_service.ModelsService(ap) - ap.model_service = model_service_inst + llm_model_service_inst = model_service.LLMModelsService(ap) + ap.llm_model_service = llm_model_service_inst + + embedding_models_service_inst = model_service.EmbeddingModelsService(ap) + ap.embedding_models_service = embedding_models_service_inst pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst diff --git a/pkg/entity/persistence/model.py b/pkg/entity/persistence/model.py index 9eb2ccef..418cab70 100644 --- a/pkg/entity/persistence/model.py +++ b/pkg/entity/persistence/model.py @@ -23,3 +23,24 @@ class LLMModel(Base): server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now(), ) + + +class EmbeddingModel(Base): + """Embedding 模型""" + + __tablename__ = 'embedding_models' + + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py index cf856894..7bc02a32 100644 --- a/pkg/provider/modelmgr/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel): token_mgr: token.TokenManager - requester: requester.LLMAPIRequester + requester: requester.ProviderAPIRequester tool_call_supported: typing.Optional[bool] = False diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index b15e53a9..2c92eacc 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -18,7 +18,7 @@ class ModelManager: model_list: list[entities.LLMModelInfo] # deprecated - requesters: dict[str, requester.LLMAPIRequester] # deprecated + requesters: dict[str, requester.ProviderAPIRequester] # deprecated token_mgrs: dict[str, token.TokenManager] # deprecated @@ -28,9 +28,11 @@ class ModelManager: llm_models: list[requester.RuntimeLLMModel] + embedding_models: list[requester.RuntimeEmbeddingModel] + requester_components: list[engine.Component] - requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache + requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache def __init__(self, ap: app.Application): self.ap = ap @@ -38,6 +40,7 @@ class ModelManager: self.requesters = {} self.token_mgrs = {} self.llm_models = [] + self.embedding_models = [] self.requester_components = [] self.requester_dict = {} @@ -45,7 +48,7 @@ class ModelManager: self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') # forge requester class dict - requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} + requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {} for component in self.requester_components: requester_dict[component.metadata.name] = component.get_python_component_class() @@ -58,13 +61,11 @@ class ModelManager: self.ap.logger.info('Loading models from db...') self.llm_models = [] + self.embedding_models = [] # llm models result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) - llm_models = result.all() - - # load models for llm_model in llm_models: try: await self.load_llm_model(llm_model) @@ -73,11 +74,17 @@ class ModelManager: except Exception as e: self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}') + # embedding models + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + embedding_models = result.all() + for embedding_model in embedding_models: + await self.load_embedding_model(embedding_model) + async def init_runtime_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """初始化运行时模型""" + """初始化运行时 LLM 模型""" if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.LLMModel(**model_info._mapping) elif isinstance(model_info, dict): @@ -101,14 +108,47 @@ class ModelManager: return runtime_llm_model + async def init_runtime_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """初始化运行时 Embedding 模型""" + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.EmbeddingModel(**model_info._mapping) + elif isinstance(model_info, dict): + model_info = persistence_model.EmbeddingModel(**model_info) + + requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) + + await requester_inst.initialize() + + runtime_embedding_model = requester.RuntimeEmbeddingModel( + model_entity=model_info, + token_mgr=token.TokenManager( + name=model_info.uuid, + tokens=model_info.api_keys, + ), + requester=requester_inst, + ) + + return runtime_embedding_model + async def load_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """加载模型""" + """加载 LLM 模型""" runtime_llm_model = await self.init_runtime_llm_model(model_info) self.llm_models.append(runtime_llm_model) + async def load_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """加载 Embedding 模型""" + runtime_embedding_model = await self.init_runtime_embedding_model(model_info) + self.embedding_models.append(runtime_embedding_model) + async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated """通过名称获取模型""" for model in self.model_list: @@ -116,23 +156,44 @@ class ModelManager: return model raise ValueError(f'无法确定模型 {name} 的信息') - async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo: - """通过uuid获取模型""" + async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel: + """通过uuid获取 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == uuid: return model - raise ValueError(f'model {uuid} not found') + raise ValueError(f'LLM model {uuid} not found') + + async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel: + """通过uuid获取 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == uuid: + return model + raise ValueError(f'Embedding model {uuid} not found') async def remove_llm_model(self, model_uuid: str): - """移除模型""" + """移除 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == model_uuid: self.llm_models.remove(model) return - def get_available_requesters_info(self) -> list[dict]: + async def remove_embedding_model(self, model_uuid: str): + """移除 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == model_uuid: + self.embedding_models.remove(model) + return + + def get_available_requesters_info(self, model_type: str) -> list[dict]: """获取所有可用的请求器""" - return [component.to_plain_dict() for component in self.requester_components] + if model_type != '': + return [ + component.to_plain_dict() + for component in self.requester_components + if model_type in component.spec['support_type'] + ] + else: + return [component.to_plain_dict() for component in self.requester_components] def get_available_requester_info_by_name(self, name: str) -> dict | None: """通过名称获取请求器信息""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 244f4c82..9742a52c 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -20,22 +20,45 @@ class RuntimeLLMModel: token_mgr: token.TokenManager """api key管理器""" - requester: LLMAPIRequester + requester: ProviderAPIRequester """请求器实例""" def __init__( self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, - requester: LLMAPIRequester, + requester: ProviderAPIRequester, ): self.model_entity = model_entity self.token_mgr = token_mgr self.requester = requester -class LLMAPIRequester(metaclass=abc.ABCMeta): - """LLM API请求器""" +class RuntimeEmbeddingModel: + """运行时 Embedding 模型""" + + model_entity: persistence_model.EmbeddingModel + """模型数据""" + + token_mgr: token.TokenManager + """api key管理器""" + + requester: ProviderAPIRequester + """请求器实例""" + + def __init__( + self, + model_entity: persistence_model.EmbeddingModel, + token_mgr: token.TokenManager, + requester: ProviderAPIRequester, + ): + self.model_entity = model_entity + self.token_mgr = token_mgr + self.requester = requester + + +class ProviderAPIRequester(metaclass=abc.ABCMeta): + """Provider API请求器""" name: str = None @@ -74,3 +97,23 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): llm_entities.Message: 返回消息对象 """ pass + + async def invoke_embedding( + self, + query: core_entities.Query, + model: RuntimeEmbeddingModel, + input_text: str, + extra_args: dict[str, typing.Any] = {}, + ) -> list[float]: + """调用 Embedding API + + Args: + query (core_entities.Query): 请求上下文 + model (RuntimeEmbeddingModel): 使用的模型信息 + input_text (str): 输入文本 + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + + Returns: + list[float]: 返回的 embedding 向量 + """ + pass diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 38573854..b195ae51 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -15,7 +15,7 @@ from ...tools import entities as tools_entities from ....utils import image -class AnthropicMessages(requester.LLMAPIRequester): +class AnthropicMessages(requester.ProviderAPIRequester): """Anthropic Messages API 请求器""" client: anthropic.AsyncAnthropic diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index c124fed9..7dbcf3ed 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./anthropicmsgs.py diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index 24beb915..10aae30f 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./bailianchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 513086e5..98d1f13a 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -13,7 +13,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class OpenAIChatCompletions(requester.LLMAPIRequester): +class OpenAIChatCompletions(requester.ProviderAPIRequester): """OpenAI ChatCompletion API 请求器""" client: openai.AsyncClient @@ -141,3 +141,39 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def invoke_embedding( + self, + query: core_entities.Query, + model: requester.RuntimeEmbeddingModel, + input_text: str, + extra_args: dict[str, typing.Any] = {}, + ) -> list[float]: + """调用 Embedding API""" + self.client.api_key = model.token_mgr.get_token() + + args = { + 'model': model.model_entity.name, + 'input': input_text, + } + + if model.model_entity.extra_args: + args.update(model.model_entity.extra_args) + + args.update(extra_args) + + try: + resp = await self.client.embeddings.create(**args) + return resp.data[0].embedding + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + raise errors.RequesterError(f'请求参数错误: {e.message}') + except openai.AuthenticationError as e: + raise errors.RequesterError(f'无效的 api-key: {e.message}') + except openai.NotFoundError as e: + raise errors.RequesterError(f'请求路径错误: {e.message}') + except openai.RateLimitError as e: + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') + except openai.APIError as e: + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/pkg/provider/modelmgr/requesters/chatcmpl.yaml index 908b30ac..ff0de6f9 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -22,6 +22,9 @@ spec: type: integer required: true default: 120 + support_type: + - llm + - text-embedding execution: python: path: ./chatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml index ea2c7eea..6f320e66 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./deepseekchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml index 6bfc085e..73fca19c 100644 --- a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./geminichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index a18675a1..3a79bb49 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./giteeaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml index 893235b2..fbe57dad 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./lmstudiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index b8868f4d..4708f671 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -14,7 +14,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class ModelScopeChatCompletions(requester.LLMAPIRequester): +class ModelScopeChatCompletions(requester.ProviderAPIRequester): """ModelScope ChatCompletion API 请求器""" client: openai.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml index a641a672..a926d889 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./modelscopechatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml index f3ae73c8..52f7bcda 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./moonshotchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 2ea4bb7d..1456515f 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -17,7 +17,7 @@ from ....core import entities as core_entities REQUESTER_NAME: str = 'ollama-chat' -class OllamaChatCompletions(requester.LLMAPIRequester): +class OllamaChatCompletions(requester.ProviderAPIRequester): """Ollama平台 ChatCompletion API请求器""" client: ollama.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/ollamachat.yaml b/pkg/provider/modelmgr/requesters/ollamachat.yaml index 01435775..f4c4bf5a 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.yaml +++ b/pkg/provider/modelmgr/requesters/ollamachat.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./ollamachat.py diff --git a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml index 2ecee6cc..ea35bce6 100644 --- a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./openrouterchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml index 9f201aa9..a5a3421c 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./ppiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 19b3dcc3..3872cb6f 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./siliconflowchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml index 402f04e7..c711ef2d 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./volcarkchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml index 29db4eb3..2769a402 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./xaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml index a05184ef..34539d95 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./zhipuaichatcmpl.py diff --git a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx index e21317d6..ef9c6f45 100644 --- a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx +++ b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx @@ -47,6 +47,7 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/deploy/models/readme.html', }, }), + new SidebarChildVO({ id: 'pipelines', name: t('pipelines.title'), diff --git a/web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts b/web/src/app/home/models/component/ChooseRequesterEntity.ts similarity index 100% rename from web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts rename to web/src/app/home/models/component/ChooseRequesterEntity.ts diff --git a/web/src/app/home/models/component/ICreateEmbeddingField.ts b/web/src/app/home/models/component/ICreateEmbeddingField.ts new file mode 100644 index 00000000..ea198f3f --- /dev/null +++ b/web/src/app/home/models/component/ICreateEmbeddingField.ts @@ -0,0 +1,7 @@ +export interface ICreateEmbeddingField { + name: string; + model_provider: string; + url: string; + api_key: string; + extra_args?: string[]; +} diff --git a/web/src/app/home/models/ICreateLLMField.ts b/web/src/app/home/models/component/ICreateLLMField.ts similarity index 100% rename from web/src/app/home/models/ICreateLLMField.ts rename to web/src/app/home/models/component/ICreateLLMField.ts diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css new file mode 100644 index 00000000..9c6c54f7 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css @@ -0,0 +1,97 @@ +.cardContainer { + width: 100%; + height: 10rem; + background-color: #fff; + border-radius: 10px; + box-shadow: 0px 2px 2px 0 rgba(0, 0, 0, 0.2); + padding: 1.2rem; + cursor: pointer; +} + +.cardContainer:hover { + box-shadow: 0px 2px 8px 0 rgba(0, 0, 0, 0.1); +} + +.iconBasicInfoContainer { + width: 100%; + height: 100%; + display: flex; + flex-direction: row; + gap: 0.8rem; + user-select: none; +} + +.iconImage { + width: 3.8rem; + height: 3.8rem; + margin: 0.2rem; + border-radius: 50%; +} + +.basicInfoContainer { + display: flex; + flex-direction: column; + gap: 0.2rem; + min-width: 0; + width: 100%; +} + +.basicInfoText { + font-size: 1.4rem; + font-weight: bold; +} + +.providerContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; +} + +.providerIcon { + width: 1.2rem; + height: 1.2rem; + margin-top: 0.2rem; + color: #626262; +} + +.providerLabel { + font-size: 1.2rem; + font-weight: 600; + color: #626262; +} + +.baseURLContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; + width: calc(100% - 3rem); +} + +.baseURLIcon { + width: 1.2rem; + height: 1.2rem; + color: #626262; +} + +.baseURLText { + font-size: 1rem; + width: 100%; + color: #626262; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + max-width: 100%; +} + +.bigText { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + font-size: 1.4rem; + font-weight: bold; + max-width: 100%; +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx new file mode 100644 index 00000000..e3dfaf80 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx @@ -0,0 +1,53 @@ +import styles from './EmbeddingCard.module.css'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; + +export default function EmbeddingCard({ cardVO }: { cardVO: EmbeddingCardVO }) { + return ( +
+
+ icon + +
+ {/* 名称 */} +
+ {cardVO.name} +
+ {/* 厂商 */} +
+ + + + + {cardVO.providerLabel} + +
+ {/* baseURL */} +
+ + + + {cardVO.baseURL} +
+
+
+
+ ); +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts new file mode 100644 index 00000000..f6d960f6 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts @@ -0,0 +1,23 @@ +export interface IEmbeddingCardVO { + id: string; + iconURL: string; + name: string; + providerLabel: string; + baseURL: string; +} + +export class EmbeddingCardVO implements IEmbeddingCardVO { + id: string; + iconURL: string; + providerLabel: string; + name: string; + baseURL: string; + + constructor(props: IEmbeddingCardVO) { + this.id = props.id; + this.iconURL = props.iconURL; + this.providerLabel = props.providerLabel; + this.name = props.name; + this.baseURL = props.baseURL; + } +} diff --git a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx new file mode 100644 index 00000000..4658a22f --- /dev/null +++ b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx @@ -0,0 +1,563 @@ +import { ICreateEmbeddingField } from '@/app/home/models/component/ICreateEmbeddingField'; +import { useEffect, useState } from 'react'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { EmbeddingModel } from '@/app/infra/entities/api'; +import { UUID } from 'uuidjs'; + +import { zodResolver } from '@hookform/resolvers/zod'; +import { useForm } from 'react-hook-form'; +import { z } from 'zod'; +import { useTranslation } from 'react-i18next'; + +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogFooter, +} from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { toast } from 'sonner'; +import { i18nObj } from '@/i18n/I18nProvider'; + +const getExtraArgSchema = (t: (key: string) => string) => + z + .object({ + key: z.string().min(1, { message: t('models.keyNameRequired') }), + type: z.enum(['string', 'number', 'boolean']), + value: z.string(), + }) + .superRefine((data, ctx) => { + if (data.type === 'number' && isNaN(Number(data.value))) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('models.mustBeValidNumber'), + path: ['value'], + }); + } + if ( + data.type === 'boolean' && + data.value !== 'true' && + data.value !== 'false' + ) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('models.mustBeTrueOrFalse'), + path: ['value'], + }); + } + }); + +const getFormSchema = (t: (key: string) => string) => + z.object({ + name: z.string().min(1, { message: t('models.modelNameRequired') }), + model_provider: z + .string() + .min(1, { message: t('models.modelProviderRequired') }), + url: z.string().min(1, { message: t('models.requestURLRequired') }), + api_key: z.string().min(1, { message: t('models.apiKeyRequired') }), + extra_args: z.array(getExtraArgSchema(t)).optional(), + }); + +export default function EmbeddingForm({ + editMode, + initEmbeddingId, + onFormSubmit, + onFormCancel, + onEmbeddingDeleted, +}: { + editMode: boolean; + initEmbeddingId?: string; + onFormSubmit: () => void; + onFormCancel: () => void; + onEmbeddingDeleted: () => void; +}) { + const { t } = useTranslation(); + const formSchema = getFormSchema(t); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + name: '', + model_provider: '', + url: '', + api_key: 'sk-xxxxx', + extra_args: [], + }, + }); + + const [extraArgs, setExtraArgs] = useState< + { key: string; type: 'string' | 'number' | 'boolean'; value: string }[] + >([]); + + const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false); + const [requesterNameList, setRequesterNameList] = useState< + IChooseRequesterEntity[] + >([]); + const [requesterDefaultURLList, setRequesterDefaultURLList] = useState< + string[] + >([]); + const [modelTesting, setModelTesting] = useState(false); + + useEffect(() => { + initEmbeddingModelFormComponent().then(() => { + if (editMode && initEmbeddingId) { + getEmbeddingConfig(initEmbeddingId).then((val) => { + form.setValue('name', val.name); + form.setValue('model_provider', val.model_provider); + // setCurrentModelProvider(val.model_provider); + form.setValue('url', val.url); + form.setValue('api_key', val.api_key); + if (val.extra_args) { + const args = val.extra_args.map((arg) => { + const [key, value] = arg.split(':'); + let type: 'string' | 'number' | 'boolean' = 'string'; + if (!isNaN(Number(value))) { + type = 'number'; + } else if (value === 'true' || value === 'false') { + type = 'boolean'; + } + return { + key, + type, + value, + }; + }); + setExtraArgs(args); + form.setValue('extra_args', args); + } + }); + } else { + form.reset(); + } + }); + }, []); + + const addExtraArg = () => { + setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]); + }; + + const updateExtraArg = ( + index: number, + field: 'key' | 'type' | 'value', + value: string, + ) => { + const newArgs = [...extraArgs]; + newArgs[index] = { + ...newArgs[index], + [field]: value, + }; + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); + }; + + const removeExtraArg = (index: number) => { + const newArgs = extraArgs.filter((_, i) => i !== index); + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); + }; + + async function initEmbeddingModelFormComponent() { + const requesterNameList = + await httpClient.getProviderRequesters('text-embedding'); + setRequesterNameList( + requesterNameList.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }), + ); + setRequesterDefaultURLList( + requesterNameList.requesters.map((item) => { + const config = item.spec.config; + for (let i = 0; i < config.length; i++) { + if (config[i].name == 'base_url') { + return config[i].default?.toString() || ''; + } + } + return ''; + }), + ); + } + + async function getEmbeddingConfig( + id: string, + ): Promise { + const embeddingModel = await httpClient.getProviderEmbeddingModel(id); + + const fakeExtraArgs = []; + const extraArgs = embeddingModel.model.extra_args as Record; + for (const key in extraArgs) { + fakeExtraArgs.push(`${key}:${extraArgs[key]}`); + } + return { + name: embeddingModel.model.name, + model_provider: embeddingModel.model.requester, + url: embeddingModel.model.requester_config?.base_url, + api_key: embeddingModel.model.api_keys[0], + extra_args: fakeExtraArgs, + }; + } + + function handleFormSubmit(value: z.infer) { + const extraArgsObj: Record = {}; + value.extra_args?.forEach( + (arg: { key: string; type: string; value: string }) => { + if (arg.type === 'number') { + extraArgsObj[arg.key] = Number(arg.value); + } else if (arg.type === 'boolean') { + extraArgsObj[arg.key] = arg.value === 'true'; + } else { + extraArgsObj[arg.key] = arg.value; + } + }, + ); + + const embeddingModel: EmbeddingModel = { + uuid: editMode ? initEmbeddingId || '' : UUID.generate(), + name: value.name, + description: '', + requester: value.model_provider, + requester_config: { + base_url: value.url, + timeout: 120, + }, + extra_args: extraArgsObj, + api_keys: [value.api_key], + }; + + if (editMode) { + onSaveEdit(embeddingModel).then(() => { + form.reset(); + }); + } else { + onCreateEmbedding(embeddingModel).then(() => { + form.reset(); + }); + } + } + + async function onCreateEmbedding(embeddingModel: EmbeddingModel) { + try { + await httpClient.createProviderEmbeddingModel(embeddingModel); + onFormSubmit(); + toast.success(t('models.createSuccess')); + } catch (err) { + toast.error(t('models.createError') + (err as Error).message); + } + } + + async function onSaveEdit(embeddingModel: EmbeddingModel) { + try { + await httpClient.updateProviderEmbeddingModel( + initEmbeddingId || '', + embeddingModel, + ); + onFormSubmit(); + toast.success(t('models.saveSuccess')); + } catch (err) { + toast.error(t('models.saveError') + (err as Error).message); + } + } + + function deleteModel() { + if (initEmbeddingId) { + httpClient + .deleteProviderEmbeddingModel(initEmbeddingId) + .then(() => { + onEmbeddingDeleted(); + toast.success(t('models.deleteSuccess')); + }) + .catch((err) => { + toast.error(t('models.deleteError') + err.message); + }); + } + } + + function testEmbeddingModelInForm() { + setModelTesting(true); + httpClient + .testEmbeddingModel('_', { + uuid: '', + name: form.getValues('name'), + description: '', + requester: form.getValues('model_provider'), + requester_config: { + base_url: form.getValues('url'), + timeout: 120, + }, + api_keys: [form.getValues('api_key')], + }) + .then((res) => { + console.log(res); + toast.success(t('models.testSuccess')); + }) + .catch(() => { + toast.error(t('models.testError')); + }) + .finally(() => { + setModelTesting(false); + }); + } + + return ( +
+ + + + {t('common.confirmDelete')} + + + {t('models.deleteConfirmation')} + + + + + + + + +
+ +
+ ( + + + {t('models.modelName')} + * + + + + + + + {t('models.modelProviderDescription')} + + + )} + /> + + ( + + + {t('models.modelProvider')} + * + + + + + + + )} + /> + + ( + + + {t('models.requestURL')} + * + + + + + + + )} + /> + + ( + + + {t('models.apiKey')} + * + + + + + + + )} + /> + + + {t('models.extraParameters')} +
+ {extraArgs.map((arg, index) => ( +
+ + updateExtraArg(index, 'key', e.target.value) + } + /> + + + updateExtraArg(index, 'value', e.target.value) + } + /> + +
+ ))} + +
+ + {t('embedding.extraParametersDescription')} + + +
+
+ + {editMode && ( + + )} + + + + + + + +
+ +
+ ); +} diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index f483f183..73cc32fe 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -1,6 +1,6 @@ -import { ICreateLLMField } from '@/app/home/models/ICreateLLMField'; +import { ICreateLLMField } from '@/app/home/models/component/ICreateLLMField'; import { useEffect, useState } from 'react'; -import { IChooseRequesterEntity } from '@/app/home/models/component/llm-form/ChooseRequesterEntity'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; import { UUID } from 'uuidjs'; @@ -197,7 +197,7 @@ export default function LLMForm({ }; async function initLLMModelFormComponent() { - const requesterNameList = await httpClient.getProviderRequesters(); + const requesterNameList = await httpClient.getProviderRequesters('llm'); setRequesterNameList( requesterNameList.requesters.map((item) => { return { @@ -596,7 +596,7 @@ export default function LLMForm({ - {t('models.extraParametersDescription')} + {t('llm.extraParametersDescription')} diff --git a/web/src/app/home/models/page.tsx b/web/src/app/home/models/page.tsx index 3ccec486..2f936753 100644 --- a/web/src/app/home/models/page.tsx +++ b/web/src/app/home/models/page.tsx @@ -8,6 +8,7 @@ import LLMForm from '@/app/home/models/component/llm-form/LLMForm'; import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; import { Dialog, DialogContent, @@ -17,6 +18,9 @@ import { import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; import { i18nObj } from '@/i18n/I18nProvider'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; +import EmbeddingCard from '@/app/home/models/component/embedding-card/EmbeddingCard'; +import EmbeddingForm from '@/app/home/models/component/embedding-form/EmbeddingForm'; export default function LLMConfigPage() { const { t } = useTranslation(); @@ -24,13 +28,21 @@ export default function LLMConfigPage() { const [modalOpen, setModalOpen] = useState(false); const [isEditForm, setIsEditForm] = useState(false); const [nowSelectedLLM, setNowSelectedLLM] = useState(null); + const [embeddingCardList, setEmbeddingCardList] = useState( + [], + ); + const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false); + const [isEditEmbeddingForm, setIsEditEmbeddingForm] = useState(false); + const [nowSelectedEmbedding, setNowSelectedEmbedding] = + useState(null); useEffect(() => { getLLMModelList(); + getEmbeddingModelList(); }, []); async function getLLMModelList() { - const requesterNameListResp = await httpClient.getProviderRequesters(); + const requesterNameListResp = await httpClient.getProviderRequesters('llm'); const requesterNameList = requesterNameListResp.requesters.map((item) => { return { label: i18nObj(item.label), @@ -74,6 +86,55 @@ export default function LLMConfigPage() { setNowSelectedLLM(null); setModalOpen(true); } + function selectEmbedding(cardVO: EmbeddingCardVO) { + setIsEditEmbeddingForm(true); + setNowSelectedEmbedding(cardVO); + setEmbeddingModalOpen(true); + } + + function handleCreateEmbeddingModelClick() { + setIsEditEmbeddingForm(false); + setNowSelectedEmbedding(null); + setEmbeddingModalOpen(true); + } + async function getEmbeddingModelList() { + const requesterNameListResp = + await httpClient.getProviderRequesters('text-embedding'); + const requesterNameList = requesterNameListResp.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }); + + httpClient + .getProviderEmbeddingModels() + .then((resp) => { + const embeddingModelList: EmbeddingCardVO[] = resp.models.map( + (model: { + uuid: string; + requester: string; + name: string; + requester_config?: { base_url?: string }; + }) => { + return new EmbeddingCardVO({ + id: model.uuid, + iconURL: httpClient.getProviderRequesterIconURL(model.requester), + name: model.name, + providerLabel: + requesterNameList.find((item) => item.value === model.requester) + ?.label || model.requester.substring(0, 10), + baseURL: model.requester_config?.base_url || '', + }); + }, + ); + setEmbeddingCardList(embeddingModelList); + }) + .catch((err) => { + console.error('get Embedding model list error', err); + toast.error(t('embedding.getModelListError') + err.message); + }); + } return (
@@ -101,26 +162,108 @@ export default function LLMConfigPage() { /> -
- - {cardList.map((cardVO) => { - return ( -
{ - selectLLM(cardVO); - }} - > - + + + + + {isEditEmbeddingForm + ? t('embedding.editModel') + : t('embedding.createModel')} + + + { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + onFormCancel={() => { + setEmbeddingModalOpen(false); + }} + onEmbeddingDeleted={() => { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + /> + + + + +
+
+ + + {t('llm.llmModels')} + + + {t('embedding.embeddingModels')} + + +
+ +
+

{t('llm.description')}

- ); - })} -
+ + +
+

+ {t('embedding.description')} +

+
+
+
+ + +
+ + {cardList.map((cardVO) => { + return ( +
{ + selectLLM(cardVO); + }} + > + +
+ ); + })} +
+
+ + +
+ + {embeddingCardList.map((cardVO) => { + return ( +
{ + selectEmbedding(cardVO); + }} + > + +
+ ); + })} +
+
+
); } diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index d86a8be0..53ddf1dd 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -55,6 +55,29 @@ export interface LLMModel { // updated_at: string; } +export interface ApiRespProviderEmbeddingModels { + models: EmbeddingModel[]; +} + +export interface ApiRespProviderEmbeddingModel { + model: EmbeddingModel; +} + +export interface EmbeddingModel { + name: string; + description: string; + uuid: string; + requester: string; + requester_config: { + base_url: string; + timeout: number; + }; + extra_args?: object; + api_keys: string[]; + // created_at: string; + // updated_at: string; +} + export interface ApiRespPipelines { pipelines: Pipeline[]; } diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 5193703b..1fd335d9 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -10,6 +10,9 @@ import { ApiRespProviderLLMModels, ApiRespProviderLLMModel, LLMModel, + ApiRespProviderEmbeddingModels, + ApiRespProviderEmbeddingModel, + EmbeddingModel, ApiRespPipelines, Pipeline, ApiRespPlatformAdapters, @@ -226,8 +229,10 @@ class HttpClient { // real api request implementation // ============ Provider API ============ - public getProviderRequesters(): Promise { - return this.get('/api/v1/provider/requesters'); + public getProviderRequesters( + model_type: string, + ): Promise { + return this.get('/api/v1/provider/requesters', { type: model_type }); } public getProviderRequester(name: string): Promise { @@ -275,6 +280,39 @@ class HttpClient { return this.post(`/api/v1/provider/models/llm/${uuid}/test`, model); } + // ============ Provider Model Embedding ============ + public getProviderEmbeddingModels(): Promise { + return this.get('/api/v1/provider/models/embedding'); + } + + public getProviderEmbeddingModel( + uuid: string, + ): Promise { + return this.get(`/api/v1/provider/models/embedding/${uuid}`); + } + + public createProviderEmbeddingModel(model: EmbeddingModel): Promise { + return this.post('/api/v1/provider/models/embedding', model); + } + + public deleteProviderEmbeddingModel(uuid: string): Promise { + return this.delete(`/api/v1/provider/models/embedding/${uuid}`); + } + + public updateProviderEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.put(`/api/v1/provider/models/embedding/${uuid}`, model); + } + + public testEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model); + } + // ============ Pipeline API ============ public getGeneralPipelineMetadata(): Promise { // as designed, this method will be deprecated, and only for developer to check the prefered config schema diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 1975a521..d0df9841 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -86,14 +86,13 @@ const enUS = { string: 'String', number: 'Number', boolean: 'Boolean', - extraParametersDescription: - 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', selectModelProvider: 'Select Model Provider', modelProviderDescription: 'Please fill in the model name provided by the supplier', selectModel: 'Select Model', testSuccess: 'Test successful', testError: 'Test failed, please check your model configuration', + llmModels: 'LLM Models', }, bots: { title: 'Bots', @@ -259,6 +258,21 @@ const enUS = { 'Password reset failed, please check your email and recovery key', backToLogin: 'Back to Login', }, + embedding: { + description: 'Manage Embedding models for text vectorization', + createModel: 'Create Embedding Model', + editModel: 'Edit Embedding Model', + getModelListError: 'Failed to get Embedding model list: ', + embeddingModels: 'Embedding', + extraParametersDescription: + 'Will be attached to the request body, such as encoding_format, dimensions, etc.', + }, + llm: { + description: 'Manage LLM models for conversation generation', + llmModels: 'LLM', + extraParametersDescription: + 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', + }, }; export default enUS; diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 2ded8236..96acc0e6 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -87,13 +87,12 @@ const zhHans = { string: '字符串', number: '数字', boolean: '布尔值', - extraParametersDescription: - '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', selectModelProvider: '选择模型供应商', modelProviderDescription: '请填写供应商向您提供的模型名称', selectModel: '请选择模型', testSuccess: '测试成功', testError: '测试失败,请检查模型配置', + llmModels: '对话模型', }, bots: { title: '机器人', @@ -251,6 +250,21 @@ const zhHans = { resetFailed: '密码重置失败,请检查邮箱和恢复密钥是否正确', backToLogin: '返回登录', }, + embedding: { + description: '管理嵌入模型,用于向量化文本', + createModel: '创建嵌入模型', + editModel: '编辑嵌入模型', + getModelListError: '获取嵌入模型列表失败:', + embeddingModels: '嵌入模型', + extraParametersDescription: + '将在请求时附加到请求体中,如 encoding_format, dimensions 等', + }, + llm: { + llmModels: '对话模型', + description: '管理 LLM 模型,用于对话消息生成', + extraParametersDescription: + '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', + }, }; export default zhHans; From 157ffdc34c96d3faf7bc6b9d22d362e55834c678 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Tue, 10 Jun 2025 08:34:53 +0800 Subject: [PATCH 02/60] feat: add knowledge page --- web/src/app/home/knowledge/page.tsx | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 web/src/app/home/knowledge/page.tsx diff --git a/web/src/app/home/knowledge/page.tsx b/web/src/app/home/knowledge/page.tsx new file mode 100644 index 00000000..9707a8ee --- /dev/null +++ b/web/src/app/home/knowledge/page.tsx @@ -0,0 +1,5 @@ +'use client'; + +export default function KnowledgePage() { + return
KnowledgePage
; +} From 348f6d9eaa0759638a37e752502f361c5e848f75 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Wed, 11 Jun 2025 20:24:42 +0800 Subject: [PATCH 03/60] feat: add api for uploading files --- pkg/api/http/controller/groups/files.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pkg/api/http/controller/groups/files.py b/pkg/api/http/controller/groups/files.py index 0a8b2210..d08cbd71 100644 --- a/pkg/api/http/controller/groups/files.py +++ b/pkg/api/http/controller/groups/files.py @@ -2,6 +2,10 @@ from __future__ import annotations import quart import mimetypes +import uuid +import asyncio + +import quart.datastructures from .. import group @@ -20,3 +24,22 @@ class FilesRouterGroup(group.RouterGroup): mime_type = 'image/jpeg' return quart.Response(image_bytes, mimetype=mime_type) + + @self.route('/documents', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> quart.Response: + request = quart.request + # get file bytes from 'file' + file = (await request.files)['file'] + assert isinstance(file, quart.datastructures.FileStorage) + + file_bytes = await asyncio.to_thread(file.stream.read) + extension = file.filename.split('.')[-1] + + file_key = str(uuid.uuid4()) + '.' + extension + # save file to storage + await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes) + return self.success( + data={ + 'file_id': file_key, + } + ) From 4bcc06c9559f08f8e84b3c95a94eb4258da22688 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Wed, 25 Jun 2025 14:32:53 +0800 Subject: [PATCH 04/60] kb --- .../http/controller/groups/knowledge_base.py | 83 +++++ pkg/core/app.py | 7 + pkg/rag/knowledge/RAG_Manager.py | 283 +++++++++++++++++ pkg/rag/knowledge/services/__init__.py | 0 pkg/rag/knowledge/services/base_service.py | 26 ++ pkg/rag/knowledge/services/chroma_manager.py | 65 ++++ pkg/rag/knowledge/services/chunker.py | 63 ++++ pkg/rag/knowledge/services/database.py | 57 ++++ pkg/rag/knowledge/services/embedder.py | 93 ++++++ .../knowledge/services/embedding_models.py | 223 ++++++++++++++ pkg/rag/knowledge/services/parser.py | 288 ++++++++++++++++++ pkg/rag/knowledge/services/retriever.py | 106 +++++++ pkg/rag/knowledge/utils/crawler.py | 0 pyproject.toml | 11 + 14 files changed, 1305 insertions(+) create mode 100644 pkg/api/http/controller/groups/knowledge_base.py create mode 100644 pkg/rag/knowledge/RAG_Manager.py create mode 100644 pkg/rag/knowledge/services/__init__.py create mode 100644 pkg/rag/knowledge/services/base_service.py create mode 100644 pkg/rag/knowledge/services/chroma_manager.py create mode 100644 pkg/rag/knowledge/services/chunker.py create mode 100644 pkg/rag/knowledge/services/database.py create mode 100644 pkg/rag/knowledge/services/embedder.py create mode 100644 pkg/rag/knowledge/services/embedding_models.py create mode 100644 pkg/rag/knowledge/services/parser.py create mode 100644 pkg/rag/knowledge/services/retriever.py create mode 100644 pkg/rag/knowledge/utils/crawler.py diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py new file mode 100644 index 00000000..c819397a --- /dev/null +++ b/pkg/api/http/controller/groups/knowledge_base.py @@ -0,0 +1,83 @@ +import quart +from __future__ import annotations +from .. import group + +@group.group_class('knowledge_base', '/api/v1/knowledge/bases') +class KnowledgeBaseRouterGroup(group.RouterGroup): + + # 定义成功方法 + def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response: + return quart.jsonify({ + "code": code, + "data": data or {}, + "msg": msg + }) + + + + async def initialize(self) -> None: + rag = self.ap.knowledge_base_service.RAG_Manager() + + @self.route('', methods=['POST', 'GET']) + async def _() -> str: + + if quart.request.method == 'GET': + knowledge_bases = await rag.get_all_knowledge_bases() + bases_list = [ + { + "uuid": kb.id, + "name": kb.name, + "description": kb.description, + } for kb in knowledge_bases + ] + return self.success(code=0, + data={'bases': bases_list}, + msg='ok') + + json_data = await quart.request.json + knowledge_base_uuid = await rag.create_knowledge_base( + json_data.get('name'), + json_data.get('description') + ) + return self.success() + + + @self.route('/', methods=['GET']) + async def _(knowledge_base_uuid: str) -> str: + if quart.request.method == 'GET': + knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid) + + if knowledge_base is None: + return self.http_status(404, -1, 'knowledge base not found') + + return self.success( + code=0, + data={ + "name": knowledge_base.name, + "description": knowledge_base.description, + "uuid": knowledge_base.id + }, + msg='ok' + ) + + @self.route('//files', methods=['GET']) + async def _(knowledge_base_uuid: str) -> str: + if quart.request.method == 'GET': + files = await rag.get_files_by_knowledge_base(knowledge_base_uuid) + return self.success(code=0,data=[{ + "id": file.id, + "file_name": file.file_name, + "status": file.status + } for file in files],msg='ok') + + # delete specific file in knowledge base + @self.route('//files/', methods=['DELETE']) + async def _(knowledge_base_uuid: str, file_id: str) -> str: + await rag.delete_data_by_file_id(file_id) + return self.success(code=0, msg='ok') + + # delete specific kb + @self.route('/', methods=['DELETE']) + async def _(knowledge_base_uuid: str) -> str: + await rag.delete_kb_by_id(knowledge_base_uuid) + return self.success(code=0, msg='ok') diff --git a/pkg/core/app.py b/pkg/core/app.py index 318cddcb..d8824466 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -27,6 +27,10 @@ from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities +from ...pkg.rag.knowledge import RAG_Manager + + + class Application: @@ -99,6 +103,8 @@ class Application: storage_mgr: storagemgr.StorageMgr = None + knowledge_base_service: RAG_Manager = None + # ========= HTTP Services ========= user_service: user_service.UserService = None @@ -111,6 +117,7 @@ class Application: bot_service: bot_service.BotService = None + def __init__(self): pass diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py new file mode 100644 index 00000000..e172c132 --- /dev/null +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -0,0 +1,283 @@ +# RAG_Manager class (main class, adjust imports as needed) +import logging +import os +import asyncio +from services.parser import FileParser +from services.chunker import Chunker +from services.embedder import Embedder +from services.retriever import Retriever +from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly +from services.embedding_models import EmbeddingModelFactory +from services.chroma_manager import ChromaIndexManager +from ...core import app + +class RAG_Manager: + def __init__(self, logger: logging.Logger = None): + self.logger = logger or logging.getLogger(__name__) + self.embedding_model_type = None + self.embedding_model_name = None + self.chroma_manager = None + self.parser = None + self.chunker = None + self.embedder = None + self.retriever = None + self.ap = app.Application + + async def initialize_system(self): + await asyncio.to_thread(create_db_and_tables) + + async def create_model(self, embedding_model_type: str, + embedding_model_name: str): + self.embedding_model_type = embedding_model_type + self.embedding_model_name = embedding_model_name + + try: + model = EmbeddingModelFactory.create_model( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name + ) + self.logger.info(f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}") + except Exception as e: + self.logger.critical(f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}") + raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.") + + self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}") + + self.parser = FileParser() + self.chunker = Chunker() + # Pass chroma_manager to Embedder and Retriever + self.embedder = Embedder( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name, + chroma_manager=self.chroma_manager # Inject dependency + ) + self.retriever = Retriever( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name, + chroma_manager=self.chroma_manager # Inject dependency + ) + + + async def create_knowledge_base(self, kb_name: str, kb_description: str): + """ + Creates a new knowledge base with the given name and description. + If a knowledge base with the same name already exists, it returns that one. + """ + try: + def _get_kb_sync(name): + session = SessionLocal() + try: + return session.query(KnowledgeBase).filter_by(name=name).first() + finally: + session.close() + + kb = await asyncio.to_thread(_get_kb_sync, kb_name) + + if not kb: + def _add_kb_sync(): + session = SessionLocal() + try: + new_kb = KnowledgeBase(name=kb_name, description=kb_description) + session.add(new_kb) + session.commit() + session.refresh(new_kb) + return new_kb + finally: + session.close() + kb = await asyncio.to_thread(_add_kb_sync) + except Exception as e: + self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) + raise + except Exception as e: + self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) + raise + + async def get_all_knowledge_bases(self): + """ + Retrieves all knowledge bases from the database. + """ + try: + def _get_all_kbs_sync(): + session = SessionLocal() + try: + return session.query(KnowledgeBase).all() + finally: + session.close() + + kbs = await asyncio.to_thread(_get_all_kbs_sync) + return kbs + except Exception as e: + self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True) + return [] + + async def get_knowledge_base_by_id(self, kb_id: int): + """ + Retrieves a knowledge base by its ID. + """ + try: + def _get_kb_sync(kb_id): + session = SessionLocal() + try: + return session.query(KnowledgeBase).filter_by(id=kb_id).first() + finally: + session.close() + + kb = await asyncio.to_thread(_get_kb_sync, kb_id) + return kb + except Exception as e: + self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True) + return None + + async def get_files_by_knowledge_base(self, kb_id: int): + try: + def _get_files_sync(kb_id): + session = SessionLocal() + try: + return session.query(File).filter_by(kb_id=kb_id).all() + finally: + session.close() + + files = await asyncio.to_thread(_get_files_sync, kb_id) + return files + except Exception as e: + self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True) + return [] + + + async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"): + self.logger.info(f"Starting data storage process for file: {file_path}") + try: + def _get_kb_sync(name): + session = SessionLocal() + try: + return session.query(KnowledgeBase).filter_by(name=name).first() + finally: + session.close() + + kb = await asyncio.to_thread(_get_kb_sync, kb_name) + + if not kb: + self.logger.info(f"Knowledge Base '{kb_name}' not found. Creating a new one.") + def _add_kb_sync(): + session = SessionLocal() + try: + new_kb = KnowledgeBase(name=kb_name, description=kb_description) + session.add(new_kb) + session.commit() + session.refresh(new_kb) + return new_kb + finally: + session.close() + kb = await asyncio.to_thread(_add_kb_sync) + self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})") + + def _add_file_sync(kb_id, file_name, path, file_type): + session = SessionLocal() + try: + file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type) + session.add(file) + session.commit() + session.refresh(file) + return file + finally: + session.close() + + file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type) + self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})") + + text = await self.parser.parse(file_path) + if not text: + self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.") + # You might want to delete the file_obj from the DB here if it's empty. + session = SessionLocal() + try: + session.delete(file_obj) + session.commit() + except Exception as del_e: + self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}") + finally: + session.close() + return + + chunks_texts = await self.chunker.chunk(text) + self.logger.info(f"Chunked into {len(chunks_texts)} pieces.") + + # embed_and_store now handles both DB chunk saving and Chroma embedding + await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) + + self.logger.info(f"Data storage process completed for file: {file_path}") + + except Exception as e: + self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True) + # Consider cleaning up partially stored data if an error occurs. + return + + async def retrieve_data(self, query: str): + self.logger.info(f"Starting data retrieval process for query: '{query}'") + try: + retrieved_chunks = await self.retriever.retrieve(query) + self.logger.info(f"Successfully retrieved {len(retrieved_chunks)} chunks for query.") + return retrieved_chunks + except Exception as e: + self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) + return [] + + async def delete_data_by_file_id(self, file_id: int): + """ + Deletes data associated with a specific file_id from both the relational DB and Chroma. + """ + self.logger.info(f"Starting data deletion process for file_id: {file_id}") + session = SessionLocal() + try: + # 1. Delete from Chroma + await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) + + # 2. Delete chunks from relational DB + chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() + for chunk in chunks_to_delete: + session.delete(chunk) + self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.") + + # 3. Delete file entry from relational DB + file_to_delete = session.query(File).filter_by(id=file_id).first() + if file_to_delete: + session.delete(file_to_delete) + self.logger.info(f"Deleted file entry {file_id} from relational DB.") + else: + self.logger.warning(f"File entry {file_id} not found in relational DB.") + + session.commit() + self.logger.info(f"Data deletion completed for file_id: {file_id}.") + except Exception as e: + session.rollback() + self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True) + finally: + session.close() + + async def delete_kb_by_id(self, kb_id: int): + """ + Deletes a knowledge base and all associated files and chunks. + """ + self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}") + session = SessionLocal() + try: + # 1. Get the knowledge base + kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if not kb: + self.logger.warning(f"Knowledge Base with ID {kb_id} not found.") + return + + # 2. Delete all files associated with this knowledge base + files_to_delete = session.query(File).filter_by(kb_id=kb.id).all() + for file in files_to_delete: + await self.delete_data_by_file_id(file.id) + + # 3. Delete the knowledge base itself + session.delete(kb) + session.commit() + self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}") + except Exception as e: + session.rollback() + self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True) + finally: + session.close() diff --git a/pkg/rag/knowledge/services/__init__.py b/pkg/rag/knowledge/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py new file mode 100644 index 00000000..0298226a --- /dev/null +++ b/pkg/rag/knowledge/services/base_service.py @@ -0,0 +1,26 @@ +# 封装异步操作 +import asyncio +import logging +from services.database import SessionLocal # 导入 SessionLocal 工厂函数 + +class BaseService: + def __init__(self): + self.logger = logging.getLogger(self.__class__.__name__) + self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数 + + async def _run_sync(self, func, *args, **kwargs): + """ + 在单独的线程中运行同步函数。 + 如果第一个参数是 session,则在 to_thread 中获取新的 session。 + """ + # 如果函数需要数据库会话作为第一个参数,我们在这里获取它 + if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头 + session = await asyncio.to_thread(self.db_session_factory) + try: + result = await asyncio.to_thread(func, session, *args, **kwargs) + return result + finally: + session.close() + else: + # 否则,直接运行同步函数 + return await asyncio.to_thread(func, *args, **kwargs) \ No newline at end of file diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py new file mode 100644 index 00000000..6a469168 --- /dev/null +++ b/pkg/rag/knowledge/services/chroma_manager.py @@ -0,0 +1,65 @@ +# services/chroma_manager.py +import numpy as np +import logging +from chromadb import PersistentClient +import os + +logger = logging.getLogger(__name__) + +class ChromaIndexManager: + def __init__(self, collection_name: str = "default_collection"): + self.logger = logging.getLogger(self.__class__.__name__) + chroma_data_path = "./chroma_data" + os.makedirs(chroma_data_path, exist_ok=True) + self.client = PersistentClient(path=chroma_data_path) + self._collection_name = collection_name + self._collection = None + + self.logger.info(f"ChromaIndexManager initialized. Collection name: {self._collection_name}") + + @property + def collection(self): + if self._collection is None: + self._collection = self.client.get_or_create_collection(name=self._collection_name) + self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.") + return self._collection + + def add_embeddings_sync(self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]): + if embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents): + raise ValueError("Embedding, file_id, chunk_id, and document count mismatch.") + + chroma_ids = [f"{file_id}_{chunk_id}" for file_id, chunk_id in zip(file_ids, chunk_ids)] + metadatas = [{"file_id": fid, "chunk_id": cid} for fid, cid in zip(file_ids, chunk_ids)] + + self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") + self.collection.add( + embeddings=embeddings.tolist(), + ids=chroma_ids, + metadatas=metadatas, + documents=documents + ) + self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") + + def search_sync(self, query_embedding: np.ndarray, k: int = 5): + """ + Searches the Chroma collection for the top-k nearest neighbors. + Args: + query_embedding: A numpy array of the query embedding. + k: The number of results to return. + Returns: + A dictionary containing query results from Chroma. + """ + self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.") + results = self.collection.query( + query_embeddings=query_embedding.tolist(), + n_results=k, + # REMOVE 'ids' from the include list. It's returned by default. + include=["metadatas", "distances", "documents"] + ) + self.logger.debug(f"Chroma search returned {len(results.get('ids', [[]])[0])} results.") + return results + + def delete_by_file_id_sync(self, file_id: int): + self.logger.info(f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'.") + self.collection.delete(where={"file_id": file_id}) + self.logger.info(f"Deleted embeddings for file_id: {file_id} from Chroma.") \ No newline at end of file diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py new file mode 100644 index 00000000..f115dac4 --- /dev/null +++ b/pkg/rag/knowledge/services/chunker.py @@ -0,0 +1,63 @@ +# services/chunker.py +import logging +from typing import List +from services.base_service import BaseService # Assuming BaseService provides _run_sync + +logger = logging.getLogger(__name__) + +class Chunker(BaseService): + """ + A class for splitting long texts into smaller, overlapping chunks. + """ + def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50): + super().__init__() # Initialize BaseService + self.logger = logging.getLogger(self.__class__.__name__) + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + if self.chunk_overlap >= self.chunk_size: + self.logger.warning("Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.") + + def _split_text_sync(self, text: str) -> List[str]: + """ + Synchronously splits a long text into chunks with specified overlap. + This is a CPU-bound operation, intended to be run in a separate thread. + """ + if not text: + return [] + + # Simple whitespace-based splitting for demonstration + # For more advanced chunking, consider libraries like LangChain's text splitters + words = text.split() + chunks = [] + current_chunk = [] + + for word in words: + current_chunk.append(word) + if len(current_chunk) > self.chunk_size: + chunks.append(" ".join(current_chunk[:self.chunk_size])) + current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:] + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + # A more robust chunking strategy (e.g., using recursive character text splitter) + # from langchain.text_splitter import RecursiveCharacterTextSplitter + # text_splitter = RecursiveCharacterTextSplitter( + # chunk_size=self.chunk_size, + # chunk_overlap=self.chunk_overlap, + # length_function=len, + # is_separator_regex=False, + # ) + # return text_splitter.split_text(text) + + return [chunk for chunk in chunks if chunk.strip()] # Filter out empty chunks + + async def chunk(self, text: str) -> List[str]: + """ + Asynchronously chunks a given text into smaller pieces. + """ + self.logger.info(f"Chunking text (length: {len(text)})...") + # Run the synchronous splitting logic in a separate thread + chunks = await self._run_sync(self._split_text_sync, text) + self.logger.info(f"Text chunked into {len(chunks)} pieces.") + return chunks \ No newline at end of file diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py new file mode 100644 index 00000000..4ec21af3 --- /dev/null +++ b/pkg/rag/knowledge/services/database.py @@ -0,0 +1,57 @@ +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary +from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from datetime import datetime +import numpy as np # 用于处理从LargeBinary转换回来的embedding + +Base = declarative_base() + +class KnowledgeBase(Base): + __tablename__ = 'kb' + id = Column(Integer, primary_key=True, index=True) + name = Column(String, index=True) + description = Column(Text) + created_at = Column(DateTime, default=datetime.utcnow) + + files = relationship("File", back_populates="knowledge_base") + +class File(Base): + __tablename__ = 'file' + id = Column(Integer, primary_key=True, index=True) + kb_id = Column(Integer, ForeignKey('kb.id')) + file_name = Column(String) + path = Column(String) + created_at = Column(DateTime, default=datetime.utcnow) + file_type = Column(String) + status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误 + knowledge_base = relationship("KnowledgeBase", back_populates="files") + chunks = relationship("Chunk", back_populates="file") + +class Chunk(Base): + __tablename__ = 'chunks' + id = Column(Integer, primary_key=True, index=True) + file_id = Column(Integer, ForeignKey('file.id')) + text = Column(Text) + + file = relationship("File", back_populates="chunks") + vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship("Chunk", back_populates="vector") + +# 数据库连接 +DATABASE_URL = "sqlite:///./knowledge_base.db" # 生产环境请更换为 PostgreSQL/MySQL +engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建所有表 (可以在应用启动时执行一次) +def create_db_and_tables(): + Base.metadata.create_all(bind=engine) + print("Database tables created/checked.") + +# 定义嵌入维度(请根据你实际使用的模型调整) +EMBEDDING_DIM = 1024 \ No newline at end of file diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py new file mode 100644 index 00000000..2b581e96 --- /dev/null +++ b/pkg/rag/knowledge/services/embedder.py @@ -0,0 +1,93 @@ +# services/embedder.py +import asyncio +import logging +import numpy as np +from typing import List +from sqlalchemy.orm import Session +from services.base_service import BaseService +from services.database import Chunk, SessionLocal +from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from services.chroma_manager import ChromaIndexManager # Import the manager + +logger = logging.getLogger(__name__) + +class Embedder(BaseService): + def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): + super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.model_type = model_type + self.model_name_key = model_name_key + self.chroma_manager = chroma_manager # Dependency Injection + + self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() + + def _load_embedding_model(self) -> BaseEmbeddingModel: + self.logger.info(f"Loading embedding model: type={self.model_type}, name_key={self.model_name_key}...") + try: + model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) + self.logger.info(f"Embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") + return model + except Exception as e: + self.logger.error(f"Failed to load embedding model '{self.model_name_key}': {e}") + raise + + def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]): + """ + Saves chunks to the relational database and returns the created Chunk objects. + This function assumes it's called within a context where the session + will be committed/rolled back and closed by the caller. + """ + self.logger.debug(f"Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).") + chunk_objects = [] + for text in chunks_texts: + chunk = Chunk(file_id=file_id, text=text) + session.add(chunk) + chunk_objects.append(chunk) + session.flush() # This populates the .id attribute for each new chunk object + self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.") + return chunk_objects + + async def embed_and_store(self, file_id: int, chunks: List[str]): + if not self.embedding_model: + raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.") + + self.logger.info(f"Embedding {len(chunks)} chunks for file_id: {file_id} using {self.model_name_key}...") + + session = SessionLocal() # Start a session that will live for the whole operation + chunk_objects = [] + try: + # 1. Save chunks to the relational database first to get their IDs + # We call _db_save_chunks_sync directly without _run_sync's session management + # because we manage the session here across multiple async calls. + chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks) + session.commit() # Commit chunks to make their IDs permanent and accessible + + if not chunk_objects: + self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.") + return [] + + # 2. Generate embeddings + embeddings: List[List[float]] = await self.embedding_model.embed_documents(chunks) + embeddings_np = np.array(embeddings, dtype=np.float32) + + if embeddings_np.shape[1] != self.embedding_model.embedding_dimension: + self.logger.error(f"Mismatch in embedding dimension: Model returned {embeddings_np.shape[1]}, expected {self.embedding_model.embedding_dimension}. Aborting storage.") + raise ValueError("Embedding dimension mismatch during embedding process.") + + self.logger.info("Saving embeddings to Chroma...") + chunk_ids = [c.id for c in chunk_objects] # Now safe to access .id because session is still open and committed + file_ids_for_chroma = [file_id] * len(chunk_ids) + + await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call + self.chroma_manager.add_embeddings_sync, + file_ids_for_chroma, chunk_ids, embeddings_np, chunks # Pass original chunks texts for documents + ) + self.logger.info(f"Successfully saved {len(chunk_objects)} embeddings to Chroma.") + return chunk_objects + + except Exception as e: + session.rollback() # Rollback on any error + self.logger.error(f"Failed to process and store data for file_id {file_id}: {e}", exc_info=True) + raise # Re-raise the exception to propagate it + finally: + session.close() # Ensure the session is always closed \ No newline at end of file diff --git a/pkg/rag/knowledge/services/embedding_models.py b/pkg/rag/knowledge/services/embedding_models.py new file mode 100644 index 00000000..a6ce73ae --- /dev/null +++ b/pkg/rag/knowledge/services/embedding_models.py @@ -0,0 +1,223 @@ +# services/embedding_models.py + +import os +from typing import Dict, Any, List, Type, Optional +import logging +import aiohttp # Import aiohttp for asynchronous requests +import asyncio +from sentence_transformers import SentenceTransformer + +logger = logging.getLogger(__name__) + +# Base class for all embedding models +class BaseEmbeddingModel: + def __init__(self, model_name: str): + self.model_name = model_name + self._embedding_dimension = None + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronously embeds a list of texts.""" + raise NotImplementedError + + async def embed_query(self, text: str) -> List[float]: + """Asynchronously embeds a single query text.""" + raise NotImplementedError + + @property + def embedding_dimension(self) -> int: + """Returns the embedding dimension of the model.""" + if self._embedding_dimension is None: + raise NotImplementedError("Embedding dimension not set for this model.") + return self._embedding_dimension + +class EmbeddingModelFactory: + @staticmethod + def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel: + """ + Factory method to create an embedding model instance. + Currently only supports 'third_party_api' types. + """ + if model_name_key not in EMBEDDING_MODEL_CONFIGS: + raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.") + + config = EMBEDDING_MODEL_CONFIGS[model_name_key] + + if config['type'] == "third_party_api": + required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension'] + if not all(key in config for key in required_keys): + raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}") + + # Retrieve model_name from config if it differs from model_name_key + # Some APIs expect a specific 'model' value in the payload that might be different from the key + api_model_name = config.get('model_name', model_name_key) + + return ThirdPartyAPIEmbeddingModel( + model_name=api_model_name, # Use the model_name from config or the key + api_endpoint=config['api_endpoint'], + headers=config['headers'], + payload_template=config['payload_template'], + embedding_dimension=config['embedding_dimension'] + ) + +class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): + def __init__(self, model_name: str): + super().__init__(model_name) + try: + # SentenceTransformer is inherently synchronous, but we'll wrap its calls + # in async methods. The actual computation will still block the event loop + # if not run in a separate thread/process, but this keeps the API consistent. + self.model = SentenceTransformer(model_name) + self._embedding_dimension = self.model.get_sentence_embedding_dimension() + logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}") + except Exception as e: + logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}") + raise + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + # For CPU-bound tasks like local model inference, consider running in a thread pool + # to prevent blocking the event loop for long operations. + # For simplicity here, we'll call it directly. + return self.model.encode(texts).tolist() + + async def embed_query(self, text: str) -> List[float]: + return self.model.encode(text).tolist() + + +class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): + def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int): + super().__init__(model_name) + self.api_endpoint = api_endpoint + self.headers = headers + self.payload_template = payload_template + self._embedding_dimension = embedding_dimension + self.session = None # aiohttp client session will be initialized on first use or in a context manager + logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}") + + async def _get_session(self): + """Lazily create or return the aiohttp client session.""" + if self.session is None or self.session.closed: + self.session = aiohttp.ClientSession() + return self.session + + async def close_session(self): + """Explicitly close the aiohttp client session.""" + if self.session and not self.session.closed: + await self.session.close() + self.session = None + logger.info(f"Closed aiohttp session for model {self.model_name}") + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronously embeds a list of texts using the third-party API.""" + session = await self._get_session() + embeddings = [] + tasks = [] + for text in texts: + payload = self.payload_template.copy() + if 'input' in payload: + payload['input'] = text + elif 'texts' in payload: + payload['texts'] = [text] + else: + raise ValueError("Payload template does not contain expected text input key.") + + tasks.append(self._make_api_request(session, payload)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, res in enumerate(results): + if isinstance(res, Exception): + logger.error(f"Error embedding text '{texts[i][:50]}...': {res}") + # Depending on your error handling strategy, you might: + # - Append None or an empty list + # - Re-raise the exception to stop processing + # - Log and skip, then continue + embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure + else: + embeddings.append(res) + + return embeddings + + async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]: + """Helper to make an asynchronous API request and extract embedding.""" + try: + async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response: + response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx) + api_response = await response.json() + + # Adjust this based on your API's actual response structure + if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]: + embedding = api_response["data"][0]["embedding"] + if len(embedding) != self.embedding_dimension: + logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + return embedding + elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]: + embedding = api_response["embeddings"][0] + if len(embedding) != self.embedding_dimension: + logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + return embedding + else: + raise ValueError(f"Unexpected API response structure: {api_response}") + + except aiohttp.ClientError as e: + raise ConnectionError(f"API request failed: {e}") from e + except ValueError as e: + raise ValueError(f"Error processing API response: {e}") from e + + + async def embed_query(self, text: str) -> List[float]: + """Asynchronously embeds a single query text.""" + results = await self.embed_documents([text]) + if results: + return results[0] + return [] # Or raise an error if embedding a query must always succeed + +# --- Embedding Model Configuration --- +EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = { + "MiniLM": { # Example for a local Sentence Transformer model + "type": "sentence_transformer", + "model_name": "sentence-transformers/all-MiniLM-L6-v2" + }, + "bge-m3": { # Example for a third-party API model + "type": "third_party_api", + "model_name": "bge-m3", + "api_endpoint": "https://api.qhaigc.net/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('rag_api_key')}" + }, + "payload_template": { + "model": "bge-m3", + "input": "" + }, + "embedding_dimension": 1024 + }, + "OpenAI-Ada-002": { + "type": "third_party_api", + "model_name": "text-embedding-ada-002", + "api_endpoint": "https://api.openai.com/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set + }, + "payload_template": { + "model": "text-embedding-ada-002", + "input": "" # Text will be injected here + }, + "embedding_dimension": 1536 + }, + "OpenAI-Embedding-3-Small": { + "type": "third_party_api", + "model_name": "text-embedding-3-small", + "api_endpoint": "https://api.openai.com/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" + }, + "payload_template": { + "model": "text-embedding-3-small", + "input": "", + # "dimensions": 512 # Optional: uncomment if you want a specific output dimension + }, + "embedding_dimension": 1536 # Default max dimension for text-embedding-3-small + }, +} \ No newline at end of file diff --git a/pkg/rag/knowledge/services/parser.py b/pkg/rag/knowledge/services/parser.py new file mode 100644 index 00000000..5fa7d589 --- /dev/null +++ b/pkg/rag/knowledge/services/parser.py @@ -0,0 +1,288 @@ + +import PyPDF2 +from docx import Document +import pandas as pd +import csv +import chardet +from typing import Union, List, Callable, Any +import logging +import markdown +from bs4 import BeautifulSoup +import ebooklib +from ebooklib import epub +import re +import asyncio # Import asyncio for async operations +import os + +# Configure logging +logger = logging.getLogger(__name__) + +class FileParser: + """ + A robust file parser class to extract text content from various document formats. + It supports TXT, PDF, DOCX, XLSX, CSV, Markdown, HTML, and EPUB files. + All core file reading operations are designed to be run synchronously in a thread pool + to avoid blocking the asyncio event loop. + """ + def __init__(self): + + self.logger = logging.getLogger(self.__class__.__name__) + + async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any: + """ + Runs a synchronous function in a separate thread to prevent blocking the event loop. + This is a general utility method for wrapping blocking I/O operations. + """ + try: + return await asyncio.to_thread(sync_func, *args, **kwargs) + except Exception as e: + self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}") + raise + + async def parse(self, file_path: str) -> Union[str, None]: + """ + Parses the file based on its extension and returns the extracted text content. + This is the main asynchronous entry point for parsing. + + Args: + file_path (str): The path to the file to be parsed. + + Returns: + Union[str, None]: The extracted text content as a single string, or None if parsing fails. + """ + if not file_path or not os.path.exists(file_path): + self.logger.error(f"Invalid file path provided: {file_path}") + return None + + file_extension = file_path.split('.')[-1].lower() + parser_method = getattr(self, f'_parse_{file_extension}', None) + + if parser_method is None: + self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}") + return None + + try: + # Pass file_path to the specific parser methods + return await parser_method(file_path) + except Exception as e: + self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}") + return None + + # --- Helper for reading files with encoding detection --- + async def _read_file_content(self, file_path: str, mode: str = 'r') -> Union[str, bytes]: + """ + Reads a file with automatic encoding detection, ensuring the synchronous + file read operation runs in a separate thread. + """ + def _read_sync(): + with open(file_path, 'rb') as file: + raw_data = file.read() + detected = chardet.detect(raw_data) + encoding = detected['encoding'] or 'utf-8' + + if mode == 'r': + return raw_data.decode(encoding, errors='ignore') + return raw_data # For binary mode + + return await self._run_sync(_read_sync) + + # --- Specific Parser Methods --- + + async def _parse_txt(self, file_path: str) -> str: + """Parses a TXT file and returns its content.""" + self.logger.info(f"Parsing TXT file: {file_path}") + return await self._read_file_content(file_path, mode='r') + + async def _parse_pdf(self, file_path: str) -> str: + """Parses a PDF file and returns its text content.""" + self.logger.info(f"Parsing PDF file: {file_path}") + def _parse_pdf_sync(): + text_content = [] + with open(file_path, 'rb') as file: + pdf_reader = PyPDF2.PdfReader(file) + for page in pdf_reader.pages: + text = page.extract_text() + if text: + text_content.append(text) + return '\n'.join(text_content) + return await self._run_sync(_parse_pdf_sync) + + async def _parse_docx(self, file_path: str) -> str: + """Parses a DOCX file and returns its text content.""" + self.logger.info(f"Parsing DOCX file: {file_path}") + def _parse_docx_sync(): + doc = Document(file_path) + text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()] + return '\n'.join(text_content) + return await self._run_sync(_parse_docx_sync) + + async def _parse_doc(self, file_path: str) -> str: + """Handles .doc files, explicitly stating lack of direct support.""" + self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.") + raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.") + + async def _parse_xlsx(self, file_path: str) -> str: + """Parses an XLSX file, returning text from all sheets.""" + self.logger.info(f"Parsing XLSX file: {file_path}") + def _parse_xlsx_sync(): + excel_file = pd.ExcelFile(file_path) + all_sheet_content = [] + for sheet_name in excel_file.sheet_names: + df = pd.read_excel(file_path, sheet_name=sheet_name) + sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n" + all_sheet_content.append(sheet_text) + return '\n'.join(all_sheet_content) + return await self._run_sync(_parse_xlsx_sync) + + async def _parse_csv(self, file_path: str) -> str: + """Parses a CSV file and returns its content as a string.""" + self.logger.info(f"Parsing CSV file: {file_path}") + def _parse_csv_sync(): + # pd.read_csv can often detect encoding, but explicit detection is safer + raw_data = self._read_file_content(file_path, mode='rb') # Note: this will need to be await outside this sync function + # For simplicity, we'll let pandas handle encoding internally after a raw read. + # A more robust solution might pass encoding directly to pd.read_csv after detection. + detected = chardet.detect(open(file_path, 'rb').read()) + encoding = detected['encoding'] or 'utf-8' + df = pd.read_csv(file_path, encoding=encoding) + return df.to_string(index=False) + return await self._run_sync(_parse_csv_sync) + + async def _parse_markdown(self, file_path: str) -> str: + """Parses a Markdown file, converting it to structured plain text.""" + self.logger.info(f"Parsing Markdown file: {file_path}") + def _parse_markdown_sync(): + md_content = self._read_file_content(file_path, mode='r') # This is a synchronous call within a sync function + html_content = markdown.markdown( + md_content, + extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code'] + ) + soup = BeautifulSoup(html_content, 'html.parser') + text_parts = [] + for element in soup.children: + if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']: + level = int(element.name[1]) + text_parts.append('#' * level + ' ' + element.get_text().strip()) + elif element.name == 'p': + text = element.get_text().strip() + if text: + text_parts.append(text) + elif element.name in ['ul', 'ol']: + for li in element.find_all('li'): + text_parts.append(f"* {li.get_text().strip()}") + elif element.name == 'pre': + code_block = element.get_text().strip() + if code_block: + text_parts.append(f"```\n{code_block}\n```") + elif element.name == 'table': + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + if table_str: + text_parts.append(table_str) + elif element.name: + text = element.get_text(separator=' ', strip=True) + if text: + text_parts.append(text) + cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) + return cleaned_text.strip() + return await self._run_sync(_parse_markdown_sync) + + async def _parse_html(self, file_path: str) -> str: + """Parses an HTML file, extracting structured plain text.""" + self.logger.info(f"Parsing HTML file: {file_path}") + def _parse_html_sync(): + html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function + soup = BeautifulSoup(html_content, 'html.parser') + for script_or_style in soup(["script", "style"]): + script_or_style.decompose() + text_parts = [] + for element in soup.body.children if soup.body else soup.children: + if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']: + level = int(element.name[1]) + text_parts.append('#' * level + ' ' + element.get_text().strip()) + elif element.name == 'p': + text = element.get_text().strip() + if text: + text_parts.append(text) + elif element.name in ['ul', 'ol']: + for li in element.find_all('li'): + text = li.get_text().strip() + if text: + text_parts.append(f"* {text}") + elif element.name == 'table': + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + if table_str: + text_parts.append(table_str) + elif element.name: + text = element.get_text(separator=' ', strip=True) + if text: + text_parts.append(text) + cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) + return cleaned_text.strip() + return await self._run_sync(_parse_html_sync) + + async def _parse_epub(self, file_path: str) -> str: + """Parses an EPUB file, extracting metadata and content.""" + self.logger.info(f"Parsing EPUB file: {file_path}") + def _parse_epub_sync(): + book = epub.read_epub(file_path) + text_content = [] + title_meta = book.get_metadata('DC', 'title') + if title_meta: + text_content.append(f"Title: {title_meta[0][0]}") + creator_meta = book.get_metadata('DC', 'creator') + if creator_meta: + text_content.append(f"Author: {creator_meta[0][0]}") + date_meta = book.get_metadata('DC', 'date') + if date_meta: + text_content.append(f"Publish Date: {date_meta[0][0]}") + toc = book.get_toc() + if toc: + text_content.append("\n--- Table of Contents ---") + self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper + text_content.append("--- End of Table of Contents ---\n") + for item in book.get_items(): + if item.get_type() == ebooklib.ITEM_DOCUMENT: + html_content = item.get_content().decode('utf-8', errors='ignore') + soup = BeautifulSoup(html_content, 'html.parser') + for junk in soup(["script", "style", "nav", "header", "footer"]): + junk.decompose() + text = soup.get_text(separator='\n', strip=True) + text = re.sub(r'\n\s*\n', '\n\n', text) + if text: + text_content.append(text) + return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip() + return await self._run_sync(_parse_epub_sync) + + def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int): + """Recursively adds TOC items to text_content (synchronous helper).""" + indent = ' ' * level + for item in toc_list: + if isinstance(item, tuple): + chapter, subchapters = item + text_content.append(f"{indent}- {chapter.title}") + self._add_toc_items_sync(subchapters, text_content, level + 1) + else: + text_content.append(f"{indent}- {item.title}") + + def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str: + """Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous).""" + headers = [th.get_text().strip() for th in table_element.find_all('th')] + rows = [] + for tr in table_element.find_all('tr'): + cells = [td.get_text().strip() for td in tr.find_all('td')] + if cells: + rows.append(cells) + + if not headers and not rows: + return "" + + table_lines = [] + if headers: + table_lines.append(' | '.join(headers)) + table_lines.append(' | '.join(['---'] * len(headers))) + + for row_cells in rows: + padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells + table_lines.append(' | '.join(padded_cells)) + + return '\n'.join(table_lines) \ No newline at end of file diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py new file mode 100644 index 00000000..6da1c5d8 --- /dev/null +++ b/pkg/rag/knowledge/services/retriever.py @@ -0,0 +1,106 @@ +# services/retriever.py +import asyncio +import logging +import numpy as np # Make sure numpy is imported +from typing import List, Dict, Any +from sqlalchemy.orm import Session +from services.base_service import BaseService +from services.database import Chunk, SessionLocal +from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from services.chroma_manager import ChromaIndexManager + +logger = logging.getLogger(__name__) + +class Retriever(BaseService): + def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): + super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.model_type = model_type + self.model_name_key = model_name_key + self.chroma_manager = chroma_manager + + self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() + + def _load_embedding_model(self) -> BaseEmbeddingModel: + self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...") + try: + model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) + self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") + return model + except Exception as e: + self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}") + raise + + async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: + if not self.embedding_model: + raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.") + + self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}") + + query_embedding: List[float] = await self.embedding_model.embed_query(query) + query_embedding_np = np.array([query_embedding], dtype=np.float32) + + chroma_results = await self._run_sync( + self.chroma_manager.search_sync, + query_embedding_np, k + ) + + # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' + matched_chroma_ids = chroma_results.get("ids", [[]])[0] + distances = chroma_results.get("distances", [[]])[0] + chroma_metadatas = chroma_results.get("metadatas", [[]])[0] + chroma_documents = chroma_results.get("documents", [[]])[0] + + if not matched_chroma_ids: + self.logger.info("No relevant chunks found in Chroma.") + return [] + + db_chunk_ids = [] + for metadata in chroma_metadatas: + if "chunk_id" in metadata: + db_chunk_ids.append(metadata["chunk_id"]) + else: + self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.") + + if not db_chunk_ids: + self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.") + return [] + + self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...") + chunks_from_db = await self._run_sync( + lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync + db_chunk_ids + ) + + chunk_map = {chunk.id: chunk for chunk in chunks_from_db} + results_list: List[Dict[str, Any]] = [] + + for i, chroma_id in enumerate(matched_chroma_ids): + try: + # Ensure original_chunk_id is int for DB lookup + original_chunk_id = int(chroma_id.split('_')[-1]) + except (ValueError, IndexError): + self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.") + continue + + chunk_text_from_chroma = chroma_documents[i] + distance = float(distances[i]) + file_id_from_chroma = chroma_metadatas[i].get("file_id") + + chunk_from_db = chunk_map.get(original_chunk_id) + + results_list.append({ + "chunk_id": original_chunk_id, + "text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, + "distance": distance, + "file_id": file_id_from_chroma + }) + + self.logger.info(f"Retrieved {len(results_list)} chunks.") + return results_list + + def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]: + self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).") + chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all() + session.close() + return chunks \ No newline at end of file diff --git a/pkg/rag/knowledge/utils/crawler.py b/pkg/rag/knowledge/utils/crawler.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index 5e85bfb0..27a03a92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,17 @@ dependencies = [ "ruff>=0.11.9", "pre-commit>=4.2.0", "uv>=0.7.11", + "PyPDF2>=3.0.1", + "python-docx>=1.1.0", + "pandas>=2.2.2", + "chardet>=5.2.0", + "markdown>=3.6", + "beautifulsoup4>=4.12.3", + "ebooklib>=0.18", + "html2text>=2024.2.26", + "langchain>=0.2.0", + "chromadb>=0.4.24", + "sentence-transformers>=2.6.1", ] keywords = [ "bot", From c4671fbf1c046c83f8478bebb95af8865f9f6cf8 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Thu, 26 Jun 2025 14:09:26 +0800 Subject: [PATCH 05/60] delete ap --- pkg/rag/knowledge/RAG_Manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index e172c132..d85699d3 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -21,7 +21,6 @@ class RAG_Manager: self.chunker = None self.embedder = None self.retriever = None - self.ap = app.Application async def initialize_system(self): await asyncio.to_thread(create_db_and_tables) From 34fe8b324d0b5e1412641ef3b6029372c2327a90 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Thu, 3 Jul 2025 23:28:47 +0800 Subject: [PATCH 06/60] feat: add functions --- .../http/controller/groups/knowledge_base.py | 27 ++++++++-------- .../controller/groups/pipelines/pipelines.py | 2 +- pkg/core/app.py | 10 +++--- pkg/core/entities.py | 2 +- pkg/core/stages/build_app.py | 7 ++++ pkg/entity/persistence/vector.py | 14 ++++++++ pkg/rag/knowledge/RAG_Manager.py | 32 +++++++++++-------- pkg/rag/knowledge/services/base_service.py | 8 ++--- pkg/rag/knowledge/services/chroma_manager.py | 4 +-- pkg/rag/knowledge/services/chunker.py | 2 +- pkg/rag/knowledge/services/embedder.py | 8 ++--- pkg/rag/knowledge/services/retriever.py | 8 ++--- 12 files changed, 75 insertions(+), 49 deletions(-) create mode 100644 pkg/entity/persistence/vector.py diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py index c819397a..f9aa09e0 100644 --- a/pkg/api/http/controller/groups/knowledge_base.py +++ b/pkg/api/http/controller/groups/knowledge_base.py @@ -1,5 +1,4 @@ import quart -from __future__ import annotations from .. import group @group.group_class('knowledge_base', '/api/v1/knowledge/bases') @@ -16,13 +15,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): async def initialize(self) -> None: - rag = self.ap.knowledge_base_service.RAG_Manager() + @self.route('', methods=['POST', 'GET']) async def _() -> str: if quart.request.method == 'GET': - knowledge_bases = await rag.get_all_knowledge_bases() + knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases() bases_list = [ { "uuid": kb.id, @@ -35,17 +34,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): msg='ok') json_data = await quart.request.json - knowledge_base_uuid = await rag.create_knowledge_base( + knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( json_data.get('name'), json_data.get('description') ) - return self.success() + return self.success(code=0, + data={}, + msg='ok') - @self.route('/', methods=['GET']) + @self.route('/', methods=['GET','DELETE']) async def _(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid) + knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid) if knowledge_base is None: return self.http_status(404, -1, 'knowledge base not found') @@ -59,11 +60,14 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): }, msg='ok' ) + elif quart.request.method == 'DELETE': + await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid) + return self.success(code=0, msg='ok') @self.route('//files', methods=['GET']) async def _(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - files = await rag.get_files_by_knowledge_base(knowledge_base_uuid) + files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) return self.success(code=0,data=[{ "id": file.id, "file_name": file.file_name, @@ -73,11 +77,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): # delete specific file in knowledge base @self.route('//files/', methods=['DELETE']) async def _(knowledge_base_uuid: str, file_id: str) -> str: - await rag.delete_data_by_file_id(file_id) + await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) return self.success(code=0, msg='ok') - # delete specific kb - @self.route('/', methods=['DELETE']) - async def _(knowledge_base_uuid: str) -> str: - await rag.delete_kb_by_id(knowledge_base_uuid) - return self.success(code=0, msg='ok') diff --git a/pkg/api/http/controller/groups/pipelines/pipelines.py b/pkg/api/http/controller/groups/pipelines/pipelines.py index 96ca239a..1a8036cc 100644 --- a/pkg/api/http/controller/groups/pipelines/pipelines.py +++ b/pkg/api/http/controller/groups/pipelines/pipelines.py @@ -2,7 +2,7 @@ from __future__ import annotations import quart -from ... import group +from .. import group @group.group_class('pipelines', '/api/v1/pipelines') diff --git a/pkg/core/app.py b/pkg/core/app.py index d8824466..2e3c9500 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -27,10 +27,7 @@ from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities -from ...pkg.rag.knowledge import RAG_Manager - - - +from pkg.rag.knowledge.RAG_Manager import RAG_Manager class Application: @@ -51,6 +48,7 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None + # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None @@ -103,7 +101,6 @@ class Application: storage_mgr: storagemgr.StorageMgr = None - knowledge_base_service: RAG_Manager = None # ========= HTTP Services ========= @@ -117,6 +114,8 @@ class Application: bot_service: bot_service.BotService = None + knowledge_base_service: RAG_Manager = None + def __init__(self): pass @@ -152,6 +151,7 @@ class Application: name='http-api-controller', scopes=[core_entities.LifecycleControlScope.APPLICATION], ) + self.task_mgr.create_task( never_ending(), name='never-ending-task', diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 4caf18ed..8dc51e5b 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -19,7 +19,7 @@ class LifecycleControlScope(enum.Enum): APPLICATION = 'application' PLATFORM = 'platform' PLUGIN = 'plugin' - PROVIDER = 'provider' + PROVIDER = 'provider' class LauncherTypes(enum.Enum): diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 482a468b..3ba468c8 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -9,6 +9,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.tools import toolmgr as llm_tool_mgr +from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr from ...platform import botmgr as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -101,6 +102,12 @@ class BuildAppStage(stage.BootingStage): embedding_models_service_inst = model_service.EmbeddingModelsService(ap) ap.embedding_models_service = embedding_models_service_inst + knowledge_base_service_inst = knowledge_base_mgr(ap) + print("knowledge_base_service_inst1", type(knowledge_base_service_inst)) + await knowledge_base_service_inst.initialize_rag_system() + ap.knowledge_base_service = knowledge_base_service_inst + print("knowledge_base_service_inst", type(ap.knowledge_base_service)) + pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst diff --git a/pkg/entity/persistence/vector.py b/pkg/entity/persistence/vector.py new file mode 100644 index 00000000..84d1dfb1 --- /dev/null +++ b/pkg/entity/persistence/vector.py @@ -0,0 +1,14 @@ +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary +from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from datetime import datetime +import numpy as np # 用于处理从LargeBinary转换回来的embedding + +Base = declarative_base() + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship("Chunk", back_populates="vector") \ No newline at end of file diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index d85699d3..292f23ce 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -1,18 +1,24 @@ # RAG_Manager class (main class, adjust imports as needed) +from __future__ import annotations # For type hinting in Python 3.7+ import logging import os import asyncio -from services.parser import FileParser -from services.chunker import Chunker -from services.embedder import Embedder -from services.retriever import Retriever -from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly -from services.embedding_models import EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager -from ...core import app +from pkg.rag.knowledge.services.parser import FileParser +from pkg.rag.knowledge.services.chunker import Chunker +from pkg.rag.knowledge.services.embedder import Embedder +from pkg.rag.knowledge.services.retriever import Retriever +from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly +from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager +from pkg.core import app # Adjust the import path as needed + class RAG_Manager: - def __init__(self, logger: logging.Logger = None): + + ap: app.Application + + def __init__(self, ap: app.Application,logger: logging.Logger = None): + self.ap = ap self.logger = logger or logging.getLogger(__name__) self.embedding_model_type = None self.embedding_model_name = None @@ -21,11 +27,11 @@ class RAG_Manager: self.chunker = None self.embedder = None self.retriever = None - - async def initialize_system(self): + + async def initialize_rag_system(self): await asyncio.to_thread(create_db_and_tables) - async def create_model(self, embedding_model_type: str, + async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str): self.embedding_model_type = embedding_model_type self.embedding_model_name = embedding_model_name @@ -57,7 +63,7 @@ class RAG_Manager: ) - async def create_knowledge_base(self, kb_name: str, kb_description: str): + async def create_knowledge_base(self, kb_name: str, kb_description: str ,): """ Creates a new knowledge base with the given name and description. If a knowledge base with the same name already exists, it returns that one. diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py index 0298226a..4ff1ce39 100644 --- a/pkg/rag/knowledge/services/base_service.py +++ b/pkg/rag/knowledge/services/base_service.py @@ -1,20 +1,20 @@ # 封装异步操作 import asyncio import logging -from services.database import SessionLocal # 导入 SessionLocal 工厂函数 +from pkg.rag.knowledge.services.database import SessionLocal class BaseService: def __init__(self): self.logger = logging.getLogger(self.__class__.__name__) - self.db_session_factory = SessionLocal # 使用 SessionLocal 工厂函数 + self.db_session_factory = SessionLocal async def _run_sync(self, func, *args, **kwargs): """ 在单独的线程中运行同步函数。 如果第一个参数是 session,则在 to_thread 中获取新的 session。 """ - # 如果函数需要数据库会话作为第一个参数,我们在这里获取它 - if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头 + + if getattr(func, '__name__', '').startswith('_db_'): session = await asyncio.to_thread(self.db_session_factory) try: result = await asyncio.to_thread(func, session, *args, **kwargs) diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py index 6a469168..f8020cdb 100644 --- a/pkg/rag/knowledge/services/chroma_manager.py +++ b/pkg/rag/knowledge/services/chroma_manager.py @@ -1,4 +1,4 @@ -# services/chroma_manager.py + import numpy as np import logging from chromadb import PersistentClient @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) class ChromaIndexManager: def __init__(self, collection_name: str = "default_collection"): self.logger = logging.getLogger(self.__class__.__name__) - chroma_data_path = "./chroma_data" + chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma")) os.makedirs(chroma_data_path, exist_ok=True) self.client = PersistentClient(path=chroma_data_path) self._collection_name = collection_name diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py index f115dac4..17202a7a 100644 --- a/pkg/rag/knowledge/services/chunker.py +++ b/pkg/rag/knowledge/services/chunker.py @@ -1,7 +1,7 @@ # services/chunker.py import logging from typing import List -from services.base_service import BaseService # Assuming BaseService provides _run_sync +from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync logger = logging.getLogger(__name__) diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 2b581e96..7e20b19a 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -4,10 +4,10 @@ import logging import numpy as np from typing import List from sqlalchemy.orm import Session -from services.base_service import BaseService -from services.database import Chunk, SessionLocal -from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager # Import the manager +from pkg.rag.knowledge.services.base_service import BaseService +from pkg.rag.knowledge.services.database import Chunk, SessionLocal +from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager logger = logging.getLogger(__name__) diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index 6da1c5d8..4da81eb1 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -4,10 +4,10 @@ import logging import numpy as np # Make sure numpy is imported from typing import List, Dict, Any from sqlalchemy.orm import Session -from services.base_service import BaseService -from services.database import Chunk, SessionLocal -from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory -from services.chroma_manager import ChromaIndexManager +from pkg.rag.knowledge.services.base_service import BaseService +from pkg.rag.knowledge.services.database import Chunk, SessionLocal +from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager logger = logging.getLogger(__name__) From 552fee9bacf0ade89cabd77ee12f886c7d8693a1 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Sat, 5 Jul 2025 17:53:11 +0800 Subject: [PATCH 07/60] fix: modify rag database --- pkg/rag/knowledge/RAG_Manager.py | 6 +++--- pkg/rag/knowledge/services/database.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index 292f23ce..6ded737a 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -62,8 +62,8 @@ class RAG_Manager: chroma_manager=self.chroma_manager # Inject dependency ) - - async def create_knowledge_base(self, kb_name: str, kb_description: str ,): + + async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5): """ Creates a new knowledge base with the given name and description. If a knowledge base with the same name already exists, it returns that one. @@ -82,7 +82,7 @@ class RAG_Manager: def _add_kb_sync(): session = SessionLocal() try: - new_kb = KnowledgeBase(name=kb_name, description=kb_description) + new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k) session.add(new_kb) session.commit() session.refresh(new_kb) diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py index 4ec21af3..a8c35883 100644 --- a/pkg/rag/knowledge/services/database.py +++ b/pkg/rag/knowledge/services/database.py @@ -11,7 +11,8 @@ class KnowledgeBase(Base): name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) - + embedding_model = Column(String, default="") # 默认嵌入模型 + top_k = Column(Integer, default=5) # 默认返回的top_k数量 files = relationship("File", back_populates="knowledge_base") class File(Base): From d2b93b3296b24028b5edc13d4a283b296b8ff12d Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Wed, 21 May 2025 12:42:39 +0800 Subject: [PATCH 08/60] feat: add embeddings model management (#1461) * feat: add embeddings model management backend support Co-Authored-By: Junyan Qin * feat: add embeddings model management frontend support Co-Authored-By: Junyan Qin * chore: revert HttpClient URL to production setting Co-Authored-By: Junyan Qin * refactor: integrate embeddings models into models page with tabs Co-Authored-By: Junyan Qin * perf: move files * perf: remove `s` * feat: allow requester to declare supported types in manifest * feat(embedding): delete dimension and encoding format * feat: add extra_args for embedding moels * perf: i18n ref * fix: linter err * fix: lint err * fix: linter err --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Junyan Qin --- .../http/controller/groups/provider/models.py | 55 +- .../controller/groups/provider/requesters.py | 3 +- pkg/api/http/service/model.py | 89 ++- pkg/core/app.py | 4 +- pkg/core/stages/build_app.py | 7 +- pkg/entity/persistence/model.py | 21 + pkg/provider/modelmgr/entities.py | 2 +- pkg/provider/modelmgr/modelmgr.py | 89 ++- pkg/provider/modelmgr/requester.py | 51 +- .../modelmgr/requesters/anthropicmsgs.py | 2 +- .../modelmgr/requesters/anthropicmsgs.yaml | 2 + .../modelmgr/requesters/bailianchatcmpl.yaml | 2 + pkg/provider/modelmgr/requesters/chatcmpl.py | 38 +- .../modelmgr/requesters/chatcmpl.yaml | 3 + .../modelmgr/requesters/deepseekchatcmpl.yaml | 2 + .../modelmgr/requesters/geminichatcmpl.yaml | 2 + .../modelmgr/requesters/giteeaichatcmpl.yaml | 2 + .../modelmgr/requesters/lmstudiochatcmpl.yaml | 2 + .../modelmgr/requesters/modelscopechatcmpl.py | 2 +- .../requesters/modelscopechatcmpl.yaml | 2 + .../modelmgr/requesters/moonshotchatcmpl.yaml | 2 + .../modelmgr/requesters/ollamachat.py | 2 +- .../modelmgr/requesters/ollamachat.yaml | 2 + .../requesters/openrouterchatcmpl.yaml | 2 + .../modelmgr/requesters/ppiochatcmpl.yaml | 2 + .../requesters/siliconflowchatcmpl.yaml | 2 + .../modelmgr/requesters/volcarkchatcmpl.yaml | 2 + .../modelmgr/requesters/xaichatcmpl.yaml | 2 + .../modelmgr/requesters/zhipuaichatcmpl.yaml | 2 + .../home-sidebar/sidbarConfigList.tsx | 1 + .../{llm-form => }/ChooseRequesterEntity.ts | 0 .../models/component/ICreateEmbeddingField.ts | 7 + .../models/{ => component}/ICreateLLMField.ts | 0 .../embedding-card/EmbeddingCard.module.css | 97 +++ .../embedding-card/EmbeddingCard.tsx | 53 ++ .../embedding-card/EmbeddingCardVO.ts | 23 + .../embedding-form/EmbeddingForm.tsx | 563 ++++++++++++++++++ .../models/component/llm-form/LLMForm.tsx | 8 +- web/src/app/home/models/page.tsx | 183 +++++- web/src/app/infra/entities/api/index.ts | 23 + web/src/app/infra/http/HttpClient.ts | 42 +- web/src/i18n/locales/en-US.ts | 18 +- web/src/i18n/locales/zh-Hans.ts | 18 +- 43 files changed, 1370 insertions(+), 64 deletions(-) rename web/src/app/home/models/component/{llm-form => }/ChooseRequesterEntity.ts (100%) create mode 100644 web/src/app/home/models/component/ICreateEmbeddingField.ts rename web/src/app/home/models/{ => component}/ICreateLLMField.ts (100%) create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx create mode 100644 web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts create mode 100644 web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx diff --git a/pkg/api/http/controller/groups/provider/models.py b/pkg/api/http/controller/groups/provider/models.py index bb77986c..0de0c922 100644 --- a/pkg/api/http/controller/groups/provider/models.py +++ b/pkg/api/http/controller/groups/provider/models.py @@ -9,18 +9,18 @@ class LLMModelsRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={'models': await self.ap.model_service.get_llm_models()}) + return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()}) elif quart.request.method == 'POST': json_data = await quart.request.json - model_uuid = await self.ap.model_service.create_llm_model(json_data) + model_uuid = await self.ap.llm_model_service.create_llm_model(json_data) return self.success(data={'uuid': model_uuid}) @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(model_uuid: str) -> str: if quart.request.method == 'GET': - model = await self.ap.model_service.get_llm_model(model_uuid) + model = await self.ap.llm_model_service.get_llm_model(model_uuid) if model is None: return self.http_status(404, -1, 'model not found') @@ -29,11 +29,11 @@ class LLMModelsRouterGroup(group.RouterGroup): elif quart.request.method == 'PUT': json_data = await quart.request.json - await self.ap.model_service.update_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.update_llm_model(model_uuid, json_data) return self.success() elif quart.request.method == 'DELETE': - await self.ap.model_service.delete_llm_model(model_uuid) + await self.ap.llm_model_service.delete_llm_model(model_uuid) return self.success() @@ -41,6 +41,49 @@ class LLMModelsRouterGroup(group.RouterGroup): async def _(model_uuid: str) -> str: json_data = await quart.request.json - await self.ap.model_service.test_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.test_llm_model(model_uuid, json_data) + + return self.success() + + +@group.group_class('models/embedding', '/api/v1/provider/models/embedding') +class EmbeddingModelsRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['GET', 'POST']) + async def _() -> str: + if quart.request.method == 'GET': + return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + + model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data) + + return self.success(data={'uuid': model_uuid}) + + @self.route('/', methods=['GET', 'PUT', 'DELETE']) + async def _(model_uuid: str) -> str: + if quart.request.method == 'GET': + model = await self.ap.embedding_models_service.get_embedding_model(model_uuid) + + if model is None: + return self.http_status(404, -1, 'model not found') + + return self.success(data={'model': model}) + elif quart.request.method == 'PUT': + json_data = await quart.request.json + + await self.ap.embedding_models_service.update_embedding_model(model_uuid, json_data) + + return self.success() + elif quart.request.method == 'DELETE': + await self.ap.embedding_models_service.delete_embedding_model(model_uuid) + + return self.success() + + @self.route('//test', methods=['POST']) + async def _(model_uuid: str) -> str: + json_data = await quart.request.json + + await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data) return self.success() diff --git a/pkg/api/http/controller/groups/provider/requesters.py b/pkg/api/http/controller/groups/provider/requesters.py index 0f999288..af9e1540 100644 --- a/pkg/api/http/controller/groups/provider/requesters.py +++ b/pkg/api/http/controller/groups/provider/requesters.py @@ -8,7 +8,8 @@ class RequestersRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> quart.Response: - return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()}) + model_type = quart.request.args.get('type', '') + return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info(model_type)}) @self.route('/', methods=['GET']) async def _(requester_name: str) -> quart.Response: diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 74fb4e02..afeae3eb 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -10,7 +10,7 @@ from ....provider.modelmgr import requester as model_requester from ....provider import entities as llm_entities -class ModelsService: +class LLMModelsService: ap: app.Application def __init__(self, ap: app.Application) -> None: @@ -103,3 +103,90 @@ class ModelsService: funcs=[], extra_args={}, ) + + +class EmbeddingModelsService: + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_embedding_models(self) -> list[dict]: + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + + models = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models] + + async def create_embedding_model(self, model_data: dict) -> str: + model_data['uuid'] = str(uuid.uuid4()) + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) + ) + + embedding_model = await self.get_embedding_model(model_data['uuid']) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + return model_data['uuid'] + + async def get_embedding_model(self, model_uuid: str) -> dict | None: + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + model = result.first() + + if model is None: + return None + + return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) + + async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None: + if 'uuid' in model_data: + del model_data['uuid'] + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.EmbeddingModel) + .where(persistence_model.EmbeddingModel.uuid == model_uuid) + .values(**model_data) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + embedding_model = await self.get_embedding_model(model_uuid) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + async def delete_embedding_model(self, model_uuid: str) -> None: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None: + runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None + + if model_uuid != '_': + for model in self.ap.model_mgr.embedding_models: + if model.model_entity.uuid == model_uuid: + runtime_embedding_model = model + break + + if runtime_embedding_model is None: + raise Exception('model not found') + + else: + runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data) + + await runtime_embedding_model.requester.invoke_embedding( + query=None, + model=runtime_embedding_model, + input_text='Hello, world!', + extra_args={}, + ) diff --git a/pkg/core/app.py b/pkg/core/app.py index 911acd3d..318cddcb 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -103,7 +103,9 @@ class Application: user_service: user_service.UserService = None - model_service: model_service.ModelsService = None + llm_model_service: model_service.LLMModelsService = None + + embedding_models_service: model_service.EmbeddingModelsService = None pipeline_service: pipeline_service.PipelineService = None diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 6ee35610..482a468b 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -95,8 +95,11 @@ class BuildAppStage(stage.BootingStage): user_service_inst = user_service.UserService(ap) ap.user_service = user_service_inst - model_service_inst = model_service.ModelsService(ap) - ap.model_service = model_service_inst + llm_model_service_inst = model_service.LLMModelsService(ap) + ap.llm_model_service = llm_model_service_inst + + embedding_models_service_inst = model_service.EmbeddingModelsService(ap) + ap.embedding_models_service = embedding_models_service_inst pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst diff --git a/pkg/entity/persistence/model.py b/pkg/entity/persistence/model.py index 9eb2ccef..418cab70 100644 --- a/pkg/entity/persistence/model.py +++ b/pkg/entity/persistence/model.py @@ -23,3 +23,24 @@ class LLMModel(Base): server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now(), ) + + +class EmbeddingModel(Base): + """Embedding 模型""" + + __tablename__ = 'embedding_models' + + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py index cf856894..7bc02a32 100644 --- a/pkg/provider/modelmgr/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel): token_mgr: token.TokenManager - requester: requester.LLMAPIRequester + requester: requester.ProviderAPIRequester tool_call_supported: typing.Optional[bool] = False diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index b15e53a9..2c92eacc 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -18,7 +18,7 @@ class ModelManager: model_list: list[entities.LLMModelInfo] # deprecated - requesters: dict[str, requester.LLMAPIRequester] # deprecated + requesters: dict[str, requester.ProviderAPIRequester] # deprecated token_mgrs: dict[str, token.TokenManager] # deprecated @@ -28,9 +28,11 @@ class ModelManager: llm_models: list[requester.RuntimeLLMModel] + embedding_models: list[requester.RuntimeEmbeddingModel] + requester_components: list[engine.Component] - requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache + requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache def __init__(self, ap: app.Application): self.ap = ap @@ -38,6 +40,7 @@ class ModelManager: self.requesters = {} self.token_mgrs = {} self.llm_models = [] + self.embedding_models = [] self.requester_components = [] self.requester_dict = {} @@ -45,7 +48,7 @@ class ModelManager: self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') # forge requester class dict - requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} + requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {} for component in self.requester_components: requester_dict[component.metadata.name] = component.get_python_component_class() @@ -58,13 +61,11 @@ class ModelManager: self.ap.logger.info('Loading models from db...') self.llm_models = [] + self.embedding_models = [] # llm models result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) - llm_models = result.all() - - # load models for llm_model in llm_models: try: await self.load_llm_model(llm_model) @@ -73,11 +74,17 @@ class ModelManager: except Exception as e: self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}') + # embedding models + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + embedding_models = result.all() + for embedding_model in embedding_models: + await self.load_embedding_model(embedding_model) + async def init_runtime_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """初始化运行时模型""" + """初始化运行时 LLM 模型""" if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.LLMModel(**model_info._mapping) elif isinstance(model_info, dict): @@ -101,14 +108,47 @@ class ModelManager: return runtime_llm_model + async def init_runtime_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """初始化运行时 Embedding 模型""" + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.EmbeddingModel(**model_info._mapping) + elif isinstance(model_info, dict): + model_info = persistence_model.EmbeddingModel(**model_info) + + requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) + + await requester_inst.initialize() + + runtime_embedding_model = requester.RuntimeEmbeddingModel( + model_entity=model_info, + token_mgr=token.TokenManager( + name=model_info.uuid, + tokens=model_info.api_keys, + ), + requester=requester_inst, + ) + + return runtime_embedding_model + async def load_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """加载模型""" + """加载 LLM 模型""" runtime_llm_model = await self.init_runtime_llm_model(model_info) self.llm_models.append(runtime_llm_model) + async def load_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """加载 Embedding 模型""" + runtime_embedding_model = await self.init_runtime_embedding_model(model_info) + self.embedding_models.append(runtime_embedding_model) + async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated """通过名称获取模型""" for model in self.model_list: @@ -116,23 +156,44 @@ class ModelManager: return model raise ValueError(f'无法确定模型 {name} 的信息') - async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo: - """通过uuid获取模型""" + async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel: + """通过uuid获取 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == uuid: return model - raise ValueError(f'model {uuid} not found') + raise ValueError(f'LLM model {uuid} not found') + + async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel: + """通过uuid获取 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == uuid: + return model + raise ValueError(f'Embedding model {uuid} not found') async def remove_llm_model(self, model_uuid: str): - """移除模型""" + """移除 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == model_uuid: self.llm_models.remove(model) return - def get_available_requesters_info(self) -> list[dict]: + async def remove_embedding_model(self, model_uuid: str): + """移除 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == model_uuid: + self.embedding_models.remove(model) + return + + def get_available_requesters_info(self, model_type: str) -> list[dict]: """获取所有可用的请求器""" - return [component.to_plain_dict() for component in self.requester_components] + if model_type != '': + return [ + component.to_plain_dict() + for component in self.requester_components + if model_type in component.spec['support_type'] + ] + else: + return [component.to_plain_dict() for component in self.requester_components] def get_available_requester_info_by_name(self, name: str) -> dict | None: """通过名称获取请求器信息""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 244f4c82..9742a52c 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -20,22 +20,45 @@ class RuntimeLLMModel: token_mgr: token.TokenManager """api key管理器""" - requester: LLMAPIRequester + requester: ProviderAPIRequester """请求器实例""" def __init__( self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, - requester: LLMAPIRequester, + requester: ProviderAPIRequester, ): self.model_entity = model_entity self.token_mgr = token_mgr self.requester = requester -class LLMAPIRequester(metaclass=abc.ABCMeta): - """LLM API请求器""" +class RuntimeEmbeddingModel: + """运行时 Embedding 模型""" + + model_entity: persistence_model.EmbeddingModel + """模型数据""" + + token_mgr: token.TokenManager + """api key管理器""" + + requester: ProviderAPIRequester + """请求器实例""" + + def __init__( + self, + model_entity: persistence_model.EmbeddingModel, + token_mgr: token.TokenManager, + requester: ProviderAPIRequester, + ): + self.model_entity = model_entity + self.token_mgr = token_mgr + self.requester = requester + + +class ProviderAPIRequester(metaclass=abc.ABCMeta): + """Provider API请求器""" name: str = None @@ -74,3 +97,23 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): llm_entities.Message: 返回消息对象 """ pass + + async def invoke_embedding( + self, + query: core_entities.Query, + model: RuntimeEmbeddingModel, + input_text: str, + extra_args: dict[str, typing.Any] = {}, + ) -> list[float]: + """调用 Embedding API + + Args: + query (core_entities.Query): 请求上下文 + model (RuntimeEmbeddingModel): 使用的模型信息 + input_text (str): 输入文本 + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + + Returns: + list[float]: 返回的 embedding 向量 + """ + pass diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 38573854..b195ae51 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -15,7 +15,7 @@ from ...tools import entities as tools_entities from ....utils import image -class AnthropicMessages(requester.LLMAPIRequester): +class AnthropicMessages(requester.ProviderAPIRequester): """Anthropic Messages API 请求器""" client: anthropic.AsyncAnthropic diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index c124fed9..7dbcf3ed 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./anthropicmsgs.py diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index 24beb915..10aae30f 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./bailianchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 513086e5..98d1f13a 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -13,7 +13,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class OpenAIChatCompletions(requester.LLMAPIRequester): +class OpenAIChatCompletions(requester.ProviderAPIRequester): """OpenAI ChatCompletion API 请求器""" client: openai.AsyncClient @@ -141,3 +141,39 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def invoke_embedding( + self, + query: core_entities.Query, + model: requester.RuntimeEmbeddingModel, + input_text: str, + extra_args: dict[str, typing.Any] = {}, + ) -> list[float]: + """调用 Embedding API""" + self.client.api_key = model.token_mgr.get_token() + + args = { + 'model': model.model_entity.name, + 'input': input_text, + } + + if model.model_entity.extra_args: + args.update(model.model_entity.extra_args) + + args.update(extra_args) + + try: + resp = await self.client.embeddings.create(**args) + return resp.data[0].embedding + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + raise errors.RequesterError(f'请求参数错误: {e.message}') + except openai.AuthenticationError as e: + raise errors.RequesterError(f'无效的 api-key: {e.message}') + except openai.NotFoundError as e: + raise errors.RequesterError(f'请求路径错误: {e.message}') + except openai.RateLimitError as e: + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') + except openai.APIError as e: + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/pkg/provider/modelmgr/requesters/chatcmpl.yaml index 908b30ac..ff0de6f9 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -22,6 +22,9 @@ spec: type: integer required: true default: 120 + support_type: + - llm + - text-embedding execution: python: path: ./chatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml index ea2c7eea..6f320e66 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./deepseekchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml index 6bfc085e..73fca19c 100644 --- a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./geminichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index a18675a1..3a79bb49 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./giteeaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml index 893235b2..fbe57dad 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./lmstudiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index b8868f4d..4708f671 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -14,7 +14,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class ModelScopeChatCompletions(requester.LLMAPIRequester): +class ModelScopeChatCompletions(requester.ProviderAPIRequester): """ModelScope ChatCompletion API 请求器""" client: openai.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml index a641a672..a926d889 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./modelscopechatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml index f3ae73c8..52f7bcda 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./moonshotchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 2ea4bb7d..1456515f 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -17,7 +17,7 @@ from ....core import entities as core_entities REQUESTER_NAME: str = 'ollama-chat' -class OllamaChatCompletions(requester.LLMAPIRequester): +class OllamaChatCompletions(requester.ProviderAPIRequester): """Ollama平台 ChatCompletion API请求器""" client: ollama.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/ollamachat.yaml b/pkg/provider/modelmgr/requesters/ollamachat.yaml index 01435775..f4c4bf5a 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.yaml +++ b/pkg/provider/modelmgr/requesters/ollamachat.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./ollamachat.py diff --git a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml index 2ecee6cc..ea35bce6 100644 --- a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./openrouterchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml index 9f201aa9..a5a3421c 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./ppiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 19b3dcc3..3872cb6f 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./siliconflowchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml index 402f04e7..c711ef2d 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./volcarkchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml index 29db4eb3..2769a402 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./xaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml index a05184ef..34539d95 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./zhipuaichatcmpl.py diff --git a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx index e21317d6..ef9c6f45 100644 --- a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx +++ b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx @@ -47,6 +47,7 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/deploy/models/readme.html', }, }), + new SidebarChildVO({ id: 'pipelines', name: t('pipelines.title'), diff --git a/web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts b/web/src/app/home/models/component/ChooseRequesterEntity.ts similarity index 100% rename from web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts rename to web/src/app/home/models/component/ChooseRequesterEntity.ts diff --git a/web/src/app/home/models/component/ICreateEmbeddingField.ts b/web/src/app/home/models/component/ICreateEmbeddingField.ts new file mode 100644 index 00000000..ea198f3f --- /dev/null +++ b/web/src/app/home/models/component/ICreateEmbeddingField.ts @@ -0,0 +1,7 @@ +export interface ICreateEmbeddingField { + name: string; + model_provider: string; + url: string; + api_key: string; + extra_args?: string[]; +} diff --git a/web/src/app/home/models/ICreateLLMField.ts b/web/src/app/home/models/component/ICreateLLMField.ts similarity index 100% rename from web/src/app/home/models/ICreateLLMField.ts rename to web/src/app/home/models/component/ICreateLLMField.ts diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css new file mode 100644 index 00000000..9c6c54f7 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css @@ -0,0 +1,97 @@ +.cardContainer { + width: 100%; + height: 10rem; + background-color: #fff; + border-radius: 10px; + box-shadow: 0px 2px 2px 0 rgba(0, 0, 0, 0.2); + padding: 1.2rem; + cursor: pointer; +} + +.cardContainer:hover { + box-shadow: 0px 2px 8px 0 rgba(0, 0, 0, 0.1); +} + +.iconBasicInfoContainer { + width: 100%; + height: 100%; + display: flex; + flex-direction: row; + gap: 0.8rem; + user-select: none; +} + +.iconImage { + width: 3.8rem; + height: 3.8rem; + margin: 0.2rem; + border-radius: 50%; +} + +.basicInfoContainer { + display: flex; + flex-direction: column; + gap: 0.2rem; + min-width: 0; + width: 100%; +} + +.basicInfoText { + font-size: 1.4rem; + font-weight: bold; +} + +.providerContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; +} + +.providerIcon { + width: 1.2rem; + height: 1.2rem; + margin-top: 0.2rem; + color: #626262; +} + +.providerLabel { + font-size: 1.2rem; + font-weight: 600; + color: #626262; +} + +.baseURLContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; + width: calc(100% - 3rem); +} + +.baseURLIcon { + width: 1.2rem; + height: 1.2rem; + color: #626262; +} + +.baseURLText { + font-size: 1rem; + width: 100%; + color: #626262; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + max-width: 100%; +} + +.bigText { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + font-size: 1.4rem; + font-weight: bold; + max-width: 100%; +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx new file mode 100644 index 00000000..e3dfaf80 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx @@ -0,0 +1,53 @@ +import styles from './EmbeddingCard.module.css'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; + +export default function EmbeddingCard({ cardVO }: { cardVO: EmbeddingCardVO }) { + return ( +
+
+ icon + +
+ {/* 名称 */} +
+ {cardVO.name} +
+ {/* 厂商 */} +
+ + + + + {cardVO.providerLabel} + +
+ {/* baseURL */} +
+ + + + {cardVO.baseURL} +
+
+
+
+ ); +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts new file mode 100644 index 00000000..f6d960f6 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts @@ -0,0 +1,23 @@ +export interface IEmbeddingCardVO { + id: string; + iconURL: string; + name: string; + providerLabel: string; + baseURL: string; +} + +export class EmbeddingCardVO implements IEmbeddingCardVO { + id: string; + iconURL: string; + providerLabel: string; + name: string; + baseURL: string; + + constructor(props: IEmbeddingCardVO) { + this.id = props.id; + this.iconURL = props.iconURL; + this.providerLabel = props.providerLabel; + this.name = props.name; + this.baseURL = props.baseURL; + } +} diff --git a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx new file mode 100644 index 00000000..4658a22f --- /dev/null +++ b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx @@ -0,0 +1,563 @@ +import { ICreateEmbeddingField } from '@/app/home/models/component/ICreateEmbeddingField'; +import { useEffect, useState } from 'react'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { EmbeddingModel } from '@/app/infra/entities/api'; +import { UUID } from 'uuidjs'; + +import { zodResolver } from '@hookform/resolvers/zod'; +import { useForm } from 'react-hook-form'; +import { z } from 'zod'; +import { useTranslation } from 'react-i18next'; + +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogFooter, +} from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { toast } from 'sonner'; +import { i18nObj } from '@/i18n/I18nProvider'; + +const getExtraArgSchema = (t: (key: string) => string) => + z + .object({ + key: z.string().min(1, { message: t('models.keyNameRequired') }), + type: z.enum(['string', 'number', 'boolean']), + value: z.string(), + }) + .superRefine((data, ctx) => { + if (data.type === 'number' && isNaN(Number(data.value))) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('models.mustBeValidNumber'), + path: ['value'], + }); + } + if ( + data.type === 'boolean' && + data.value !== 'true' && + data.value !== 'false' + ) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('models.mustBeTrueOrFalse'), + path: ['value'], + }); + } + }); + +const getFormSchema = (t: (key: string) => string) => + z.object({ + name: z.string().min(1, { message: t('models.modelNameRequired') }), + model_provider: z + .string() + .min(1, { message: t('models.modelProviderRequired') }), + url: z.string().min(1, { message: t('models.requestURLRequired') }), + api_key: z.string().min(1, { message: t('models.apiKeyRequired') }), + extra_args: z.array(getExtraArgSchema(t)).optional(), + }); + +export default function EmbeddingForm({ + editMode, + initEmbeddingId, + onFormSubmit, + onFormCancel, + onEmbeddingDeleted, +}: { + editMode: boolean; + initEmbeddingId?: string; + onFormSubmit: () => void; + onFormCancel: () => void; + onEmbeddingDeleted: () => void; +}) { + const { t } = useTranslation(); + const formSchema = getFormSchema(t); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + name: '', + model_provider: '', + url: '', + api_key: 'sk-xxxxx', + extra_args: [], + }, + }); + + const [extraArgs, setExtraArgs] = useState< + { key: string; type: 'string' | 'number' | 'boolean'; value: string }[] + >([]); + + const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false); + const [requesterNameList, setRequesterNameList] = useState< + IChooseRequesterEntity[] + >([]); + const [requesterDefaultURLList, setRequesterDefaultURLList] = useState< + string[] + >([]); + const [modelTesting, setModelTesting] = useState(false); + + useEffect(() => { + initEmbeddingModelFormComponent().then(() => { + if (editMode && initEmbeddingId) { + getEmbeddingConfig(initEmbeddingId).then((val) => { + form.setValue('name', val.name); + form.setValue('model_provider', val.model_provider); + // setCurrentModelProvider(val.model_provider); + form.setValue('url', val.url); + form.setValue('api_key', val.api_key); + if (val.extra_args) { + const args = val.extra_args.map((arg) => { + const [key, value] = arg.split(':'); + let type: 'string' | 'number' | 'boolean' = 'string'; + if (!isNaN(Number(value))) { + type = 'number'; + } else if (value === 'true' || value === 'false') { + type = 'boolean'; + } + return { + key, + type, + value, + }; + }); + setExtraArgs(args); + form.setValue('extra_args', args); + } + }); + } else { + form.reset(); + } + }); + }, []); + + const addExtraArg = () => { + setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]); + }; + + const updateExtraArg = ( + index: number, + field: 'key' | 'type' | 'value', + value: string, + ) => { + const newArgs = [...extraArgs]; + newArgs[index] = { + ...newArgs[index], + [field]: value, + }; + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); + }; + + const removeExtraArg = (index: number) => { + const newArgs = extraArgs.filter((_, i) => i !== index); + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); + }; + + async function initEmbeddingModelFormComponent() { + const requesterNameList = + await httpClient.getProviderRequesters('text-embedding'); + setRequesterNameList( + requesterNameList.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }), + ); + setRequesterDefaultURLList( + requesterNameList.requesters.map((item) => { + const config = item.spec.config; + for (let i = 0; i < config.length; i++) { + if (config[i].name == 'base_url') { + return config[i].default?.toString() || ''; + } + } + return ''; + }), + ); + } + + async function getEmbeddingConfig( + id: string, + ): Promise { + const embeddingModel = await httpClient.getProviderEmbeddingModel(id); + + const fakeExtraArgs = []; + const extraArgs = embeddingModel.model.extra_args as Record; + for (const key in extraArgs) { + fakeExtraArgs.push(`${key}:${extraArgs[key]}`); + } + return { + name: embeddingModel.model.name, + model_provider: embeddingModel.model.requester, + url: embeddingModel.model.requester_config?.base_url, + api_key: embeddingModel.model.api_keys[0], + extra_args: fakeExtraArgs, + }; + } + + function handleFormSubmit(value: z.infer) { + const extraArgsObj: Record = {}; + value.extra_args?.forEach( + (arg: { key: string; type: string; value: string }) => { + if (arg.type === 'number') { + extraArgsObj[arg.key] = Number(arg.value); + } else if (arg.type === 'boolean') { + extraArgsObj[arg.key] = arg.value === 'true'; + } else { + extraArgsObj[arg.key] = arg.value; + } + }, + ); + + const embeddingModel: EmbeddingModel = { + uuid: editMode ? initEmbeddingId || '' : UUID.generate(), + name: value.name, + description: '', + requester: value.model_provider, + requester_config: { + base_url: value.url, + timeout: 120, + }, + extra_args: extraArgsObj, + api_keys: [value.api_key], + }; + + if (editMode) { + onSaveEdit(embeddingModel).then(() => { + form.reset(); + }); + } else { + onCreateEmbedding(embeddingModel).then(() => { + form.reset(); + }); + } + } + + async function onCreateEmbedding(embeddingModel: EmbeddingModel) { + try { + await httpClient.createProviderEmbeddingModel(embeddingModel); + onFormSubmit(); + toast.success(t('models.createSuccess')); + } catch (err) { + toast.error(t('models.createError') + (err as Error).message); + } + } + + async function onSaveEdit(embeddingModel: EmbeddingModel) { + try { + await httpClient.updateProviderEmbeddingModel( + initEmbeddingId || '', + embeddingModel, + ); + onFormSubmit(); + toast.success(t('models.saveSuccess')); + } catch (err) { + toast.error(t('models.saveError') + (err as Error).message); + } + } + + function deleteModel() { + if (initEmbeddingId) { + httpClient + .deleteProviderEmbeddingModel(initEmbeddingId) + .then(() => { + onEmbeddingDeleted(); + toast.success(t('models.deleteSuccess')); + }) + .catch((err) => { + toast.error(t('models.deleteError') + err.message); + }); + } + } + + function testEmbeddingModelInForm() { + setModelTesting(true); + httpClient + .testEmbeddingModel('_', { + uuid: '', + name: form.getValues('name'), + description: '', + requester: form.getValues('model_provider'), + requester_config: { + base_url: form.getValues('url'), + timeout: 120, + }, + api_keys: [form.getValues('api_key')], + }) + .then((res) => { + console.log(res); + toast.success(t('models.testSuccess')); + }) + .catch(() => { + toast.error(t('models.testError')); + }) + .finally(() => { + setModelTesting(false); + }); + } + + return ( +
+ + + + {t('common.confirmDelete')} + + + {t('models.deleteConfirmation')} + + + + + + + + +
+ +
+ ( + + + {t('models.modelName')} + * + + + + + + + {t('models.modelProviderDescription')} + + + )} + /> + + ( + + + {t('models.modelProvider')} + * + + + + + + + )} + /> + + ( + + + {t('models.requestURL')} + * + + + + + + + )} + /> + + ( + + + {t('models.apiKey')} + * + + + + + + + )} + /> + + + {t('models.extraParameters')} +
+ {extraArgs.map((arg, index) => ( +
+ + updateExtraArg(index, 'key', e.target.value) + } + /> + + + updateExtraArg(index, 'value', e.target.value) + } + /> + +
+ ))} + +
+ + {t('embedding.extraParametersDescription')} + + +
+
+ + {editMode && ( + + )} + + + + + + + +
+ +
+ ); +} diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index f483f183..73cc32fe 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -1,6 +1,6 @@ -import { ICreateLLMField } from '@/app/home/models/ICreateLLMField'; +import { ICreateLLMField } from '@/app/home/models/component/ICreateLLMField'; import { useEffect, useState } from 'react'; -import { IChooseRequesterEntity } from '@/app/home/models/component/llm-form/ChooseRequesterEntity'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; import { UUID } from 'uuidjs'; @@ -197,7 +197,7 @@ export default function LLMForm({ }; async function initLLMModelFormComponent() { - const requesterNameList = await httpClient.getProviderRequesters(); + const requesterNameList = await httpClient.getProviderRequesters('llm'); setRequesterNameList( requesterNameList.requesters.map((item) => { return { @@ -596,7 +596,7 @@ export default function LLMForm({ - {t('models.extraParametersDescription')} + {t('llm.extraParametersDescription')} diff --git a/web/src/app/home/models/page.tsx b/web/src/app/home/models/page.tsx index 3ccec486..2f936753 100644 --- a/web/src/app/home/models/page.tsx +++ b/web/src/app/home/models/page.tsx @@ -8,6 +8,7 @@ import LLMForm from '@/app/home/models/component/llm-form/LLMForm'; import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; import { Dialog, DialogContent, @@ -17,6 +18,9 @@ import { import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; import { i18nObj } from '@/i18n/I18nProvider'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; +import EmbeddingCard from '@/app/home/models/component/embedding-card/EmbeddingCard'; +import EmbeddingForm from '@/app/home/models/component/embedding-form/EmbeddingForm'; export default function LLMConfigPage() { const { t } = useTranslation(); @@ -24,13 +28,21 @@ export default function LLMConfigPage() { const [modalOpen, setModalOpen] = useState(false); const [isEditForm, setIsEditForm] = useState(false); const [nowSelectedLLM, setNowSelectedLLM] = useState(null); + const [embeddingCardList, setEmbeddingCardList] = useState( + [], + ); + const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false); + const [isEditEmbeddingForm, setIsEditEmbeddingForm] = useState(false); + const [nowSelectedEmbedding, setNowSelectedEmbedding] = + useState(null); useEffect(() => { getLLMModelList(); + getEmbeddingModelList(); }, []); async function getLLMModelList() { - const requesterNameListResp = await httpClient.getProviderRequesters(); + const requesterNameListResp = await httpClient.getProviderRequesters('llm'); const requesterNameList = requesterNameListResp.requesters.map((item) => { return { label: i18nObj(item.label), @@ -74,6 +86,55 @@ export default function LLMConfigPage() { setNowSelectedLLM(null); setModalOpen(true); } + function selectEmbedding(cardVO: EmbeddingCardVO) { + setIsEditEmbeddingForm(true); + setNowSelectedEmbedding(cardVO); + setEmbeddingModalOpen(true); + } + + function handleCreateEmbeddingModelClick() { + setIsEditEmbeddingForm(false); + setNowSelectedEmbedding(null); + setEmbeddingModalOpen(true); + } + async function getEmbeddingModelList() { + const requesterNameListResp = + await httpClient.getProviderRequesters('text-embedding'); + const requesterNameList = requesterNameListResp.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }); + + httpClient + .getProviderEmbeddingModels() + .then((resp) => { + const embeddingModelList: EmbeddingCardVO[] = resp.models.map( + (model: { + uuid: string; + requester: string; + name: string; + requester_config?: { base_url?: string }; + }) => { + return new EmbeddingCardVO({ + id: model.uuid, + iconURL: httpClient.getProviderRequesterIconURL(model.requester), + name: model.name, + providerLabel: + requesterNameList.find((item) => item.value === model.requester) + ?.label || model.requester.substring(0, 10), + baseURL: model.requester_config?.base_url || '', + }); + }, + ); + setEmbeddingCardList(embeddingModelList); + }) + .catch((err) => { + console.error('get Embedding model list error', err); + toast.error(t('embedding.getModelListError') + err.message); + }); + } return (
@@ -101,26 +162,108 @@ export default function LLMConfigPage() { /> -
- - {cardList.map((cardVO) => { - return ( -
{ - selectLLM(cardVO); - }} - > - + + + + + {isEditEmbeddingForm + ? t('embedding.editModel') + : t('embedding.createModel')} + + + { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + onFormCancel={() => { + setEmbeddingModalOpen(false); + }} + onEmbeddingDeleted={() => { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + /> + + + + +
+
+ + + {t('llm.llmModels')} + + + {t('embedding.embeddingModels')} + + +
+ +
+

{t('llm.description')}

- ); - })} -
+ + +
+

+ {t('embedding.description')} +

+
+
+
+ + +
+ + {cardList.map((cardVO) => { + return ( +
{ + selectLLM(cardVO); + }} + > + +
+ ); + })} +
+
+ + +
+ + {embeddingCardList.map((cardVO) => { + return ( +
{ + selectEmbedding(cardVO); + }} + > + +
+ ); + })} +
+
+
); } diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index d86a8be0..53ddf1dd 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -55,6 +55,29 @@ export interface LLMModel { // updated_at: string; } +export interface ApiRespProviderEmbeddingModels { + models: EmbeddingModel[]; +} + +export interface ApiRespProviderEmbeddingModel { + model: EmbeddingModel; +} + +export interface EmbeddingModel { + name: string; + description: string; + uuid: string; + requester: string; + requester_config: { + base_url: string; + timeout: number; + }; + extra_args?: object; + api_keys: string[]; + // created_at: string; + // updated_at: string; +} + export interface ApiRespPipelines { pipelines: Pipeline[]; } diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 5193703b..1fd335d9 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -10,6 +10,9 @@ import { ApiRespProviderLLMModels, ApiRespProviderLLMModel, LLMModel, + ApiRespProviderEmbeddingModels, + ApiRespProviderEmbeddingModel, + EmbeddingModel, ApiRespPipelines, Pipeline, ApiRespPlatformAdapters, @@ -226,8 +229,10 @@ class HttpClient { // real api request implementation // ============ Provider API ============ - public getProviderRequesters(): Promise { - return this.get('/api/v1/provider/requesters'); + public getProviderRequesters( + model_type: string, + ): Promise { + return this.get('/api/v1/provider/requesters', { type: model_type }); } public getProviderRequester(name: string): Promise { @@ -275,6 +280,39 @@ class HttpClient { return this.post(`/api/v1/provider/models/llm/${uuid}/test`, model); } + // ============ Provider Model Embedding ============ + public getProviderEmbeddingModels(): Promise { + return this.get('/api/v1/provider/models/embedding'); + } + + public getProviderEmbeddingModel( + uuid: string, + ): Promise { + return this.get(`/api/v1/provider/models/embedding/${uuid}`); + } + + public createProviderEmbeddingModel(model: EmbeddingModel): Promise { + return this.post('/api/v1/provider/models/embedding', model); + } + + public deleteProviderEmbeddingModel(uuid: string): Promise { + return this.delete(`/api/v1/provider/models/embedding/${uuid}`); + } + + public updateProviderEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.put(`/api/v1/provider/models/embedding/${uuid}`, model); + } + + public testEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model); + } + // ============ Pipeline API ============ public getGeneralPipelineMetadata(): Promise { // as designed, this method will be deprecated, and only for developer to check the prefered config schema diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 1975a521..d0df9841 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -86,14 +86,13 @@ const enUS = { string: 'String', number: 'Number', boolean: 'Boolean', - extraParametersDescription: - 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', selectModelProvider: 'Select Model Provider', modelProviderDescription: 'Please fill in the model name provided by the supplier', selectModel: 'Select Model', testSuccess: 'Test successful', testError: 'Test failed, please check your model configuration', + llmModels: 'LLM Models', }, bots: { title: 'Bots', @@ -259,6 +258,21 @@ const enUS = { 'Password reset failed, please check your email and recovery key', backToLogin: 'Back to Login', }, + embedding: { + description: 'Manage Embedding models for text vectorization', + createModel: 'Create Embedding Model', + editModel: 'Edit Embedding Model', + getModelListError: 'Failed to get Embedding model list: ', + embeddingModels: 'Embedding', + extraParametersDescription: + 'Will be attached to the request body, such as encoding_format, dimensions, etc.', + }, + llm: { + description: 'Manage LLM models for conversation generation', + llmModels: 'LLM', + extraParametersDescription: + 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', + }, }; export default enUS; diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 2ded8236..96acc0e6 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -87,13 +87,12 @@ const zhHans = { string: '字符串', number: '数字', boolean: '布尔值', - extraParametersDescription: - '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', selectModelProvider: '选择模型供应商', modelProviderDescription: '请填写供应商向您提供的模型名称', selectModel: '请选择模型', testSuccess: '测试成功', testError: '测试失败,请检查模型配置', + llmModels: '对话模型', }, bots: { title: '机器人', @@ -251,6 +250,21 @@ const zhHans = { resetFailed: '密码重置失败,请检查邮箱和恢复密钥是否正确', backToLogin: '返回登录', }, + embedding: { + description: '管理嵌入模型,用于向量化文本', + createModel: '创建嵌入模型', + editModel: '编辑嵌入模型', + getModelListError: '获取嵌入模型列表失败:', + embeddingModels: '嵌入模型', + extraParametersDescription: + '将在请求时附加到请求体中,如 encoding_format, dimensions 等', + }, + llm: { + llmModels: '对话模型', + description: '管理 LLM 模型,用于对话消息生成', + extraParametersDescription: + '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', + }, }; export default zhHans; From 6d8936bd741f3f0523ab0673a2cd6d98fa04d3c0 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Tue, 10 Jun 2025 08:34:53 +0800 Subject: [PATCH 09/60] feat: add knowledge page --- web/src/app/home/knowledge/page.tsx | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 web/src/app/home/knowledge/page.tsx diff --git a/web/src/app/home/knowledge/page.tsx b/web/src/app/home/knowledge/page.tsx new file mode 100644 index 00000000..9707a8ee --- /dev/null +++ b/web/src/app/home/knowledge/page.tsx @@ -0,0 +1,5 @@ +'use client'; + +export default function KnowledgePage() { + return
KnowledgePage
; +} From f36a61dbb20f4fbe554494aef559379c62e0943f Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Wed, 11 Jun 2025 20:24:42 +0800 Subject: [PATCH 10/60] feat: add api for uploading files --- pkg/api/http/controller/groups/files.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pkg/api/http/controller/groups/files.py b/pkg/api/http/controller/groups/files.py index 0a8b2210..d08cbd71 100644 --- a/pkg/api/http/controller/groups/files.py +++ b/pkg/api/http/controller/groups/files.py @@ -2,6 +2,10 @@ from __future__ import annotations import quart import mimetypes +import uuid +import asyncio + +import quart.datastructures from .. import group @@ -20,3 +24,22 @@ class FilesRouterGroup(group.RouterGroup): mime_type = 'image/jpeg' return quart.Response(image_bytes, mimetype=mime_type) + + @self.route('/documents', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> quart.Response: + request = quart.request + # get file bytes from 'file' + file = (await request.files)['file'] + assert isinstance(file, quart.datastructures.FileStorage) + + file_bytes = await asyncio.to_thread(file.stream.read) + extension = file.filename.split('.')[-1] + + file_key = str(uuid.uuid4()) + '.' + extension + # save file to storage + await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes) + return self.success( + data={ + 'file_id': file_key, + } + ) From 0733f8878f93a17ad281b6a71ea8685877bf78c7 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 27 Jun 2025 21:37:53 +0800 Subject: [PATCH 11/60] feat: add sidebar for rag and related i18n --- .../home-sidebar/sidbarConfigList.tsx | 19 +++++++++++++++++++ web/src/i18n/locales/en-US.ts | 4 ++++ web/src/i18n/locales/ja-JP.ts | 19 +++++++++++++++++++ web/src/i18n/locales/zh-Hans.ts | 4 ++++ 4 files changed, 46 insertions(+) diff --git a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx index ef9c6f45..1c3fb4bb 100644 --- a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx +++ b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx @@ -88,4 +88,23 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/plugin/plugin-intro.html', }, }), + new SidebarChildVO({ + id: 'knowledge', + name: t('knowledge.title'), + icon: ( + + + + ), + route: '/home/knowledge', + description: t('knowledge.description'), + helpLink: { + en_US: 'https://docs.langbot.app/en/deploy/knowledge/readme.html', + zh_Hans: 'https://docs.langbot.app/zh/deploy/knowledge/readme.html', + }, + }), ]; diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index d0df9841..5596e35f 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -230,6 +230,10 @@ const enUS = { atTips: 'Mention the bot', }, }, + knowledge: { + title: 'Knowledge', + description: 'Configuring knowledge bases for improved LLM responses', + }, register: { title: 'Initialize LangBot 👋', description: 'This is your first time starting LangBot', diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index bac6f805..21b0ff7d 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -232,6 +232,10 @@ const jaJP = { atTips: 'ボットをメンション', }, }, + knowledge: { + title: '知識ベース', + description: 'LLMの応答品質を向上させるための知識ベースを設定します', + }, register: { title: 'LangBot を初期化 👋', description: 'これはLangBotの初回起動です', @@ -260,6 +264,21 @@ const jaJP = { 'パスワードのリセットに失敗しました。メールアドレスと復旧キーを確認してください', backToLogin: 'ログインに戻る', }, + embedding: { + description: 'テキストのベクトル化に使用する埋め込みモデルを管理します', + createModel: '埋め込みモデルを作成', + editModel: '埋め込みモデルを編集', + getModelListError: '埋め込みモデルリストの取得に失敗しました:', + embeddingModels: '埋め込みモデル', + extraParametersDescription: + 'リクエストボディに追加されるパラメータ(encoding_format、dimensions など)', + }, + llm: { + description: 'チャットメッセージの生成に使用するLLMモデルを管理します', + llmModels: 'LLMモデル', + extraParametersDescription: + 'リクエストボディに追加されるパラメータ(max_tokens、temperature、top_p など)', + }, }; export default jaJP; diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 96acc0e6..1bd04ca8 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -225,6 +225,10 @@ const zhHans = { atTips: '提及机器人', }, }, + knowledge: { + title: '知识库', + description: '配置可用于提升模型回复质量的知识库', + }, register: { title: '初始化 LangBot 👋', description: '这是您首次启动 LangBot', From 22ef1a399e07d676625a82b728e1661f245bc210 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 27 Jun 2025 22:18:48 +0800 Subject: [PATCH 12/60] feat: add knowledge base page --- .../home-sidebar/sidbarConfigList.tsx | 38 +++++++++---------- .../home/knowledge/knowledgeBase.module.css | 15 ++++++++ web/src/app/home/knowledge/page.tsx | 19 +++++++++- 3 files changed, 52 insertions(+), 20 deletions(-) create mode 100644 web/src/app/home/knowledge/knowledgeBase.module.css diff --git a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx index 1c3fb4bb..b3edb98a 100644 --- a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx +++ b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx @@ -68,6 +68,25 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/deploy/pipelines/readme.html', }, }), + new SidebarChildVO({ + id: 'knowledge', + name: t('knowledge.title'), + icon: ( + + + + ), + route: '/home/knowledge', + description: t('knowledge.description'), + helpLink: { + en_US: 'https://docs.langbot.app/en/deploy/knowledge/readme.html', + zh_Hans: 'https://docs.langbot.app/zh/deploy/knowledge/readme.html', + }, + }), new SidebarChildVO({ id: 'plugins', name: t('plugins.title'), @@ -88,23 +107,4 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/plugin/plugin-intro.html', }, }), - new SidebarChildVO({ - id: 'knowledge', - name: t('knowledge.title'), - icon: ( - - - - ), - route: '/home/knowledge', - description: t('knowledge.description'), - helpLink: { - en_US: 'https://docs.langbot.app/en/deploy/knowledge/readme.html', - zh_Hans: 'https://docs.langbot.app/zh/deploy/knowledge/readme.html', - }, - }), ]; diff --git a/web/src/app/home/knowledge/knowledgeBase.module.css b/web/src/app/home/knowledge/knowledgeBase.module.css new file mode 100644 index 00000000..e811b521 --- /dev/null +++ b/web/src/app/home/knowledge/knowledgeBase.module.css @@ -0,0 +1,15 @@ +.configPageContainer { + width: 100%; + height: 100%; +} + +.knowledgeListContainer { + width: 100%; + padding-left: 0.8rem; + padding-right: 0.8rem; + display: grid; + grid-template-columns: repeat(auto-fill, minmax(24rem, 1fr)); + gap: 2rem; + justify-items: stretch; + align-items: start; +} diff --git a/web/src/app/home/knowledge/page.tsx b/web/src/app/home/knowledge/page.tsx index 9707a8ee..7c9fd048 100644 --- a/web/src/app/home/knowledge/page.tsx +++ b/web/src/app/home/knowledge/page.tsx @@ -1,5 +1,22 @@ 'use client'; +import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; +import styles from './knowledgeBase.module.css'; + export default function KnowledgePage() { - return
KnowledgePage
; + return ( +
+
+ { + // setIsEditForm(false); + // setModalOpen(true); + }} + /> +
+
+ ); } From bbf583ddb5c98239ef1f1bf755a2101b656c3126 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 29 Jun 2025 21:00:48 +0800 Subject: [PATCH 13/60] feat: basic entities of kb --- .../components/kb-card/KBCard.module.css | 107 ++++++++++++++++++ .../knowledge/components/kb-card/KBCard.tsx | 36 ++++++ .../knowledge/components/kb-card/KBCardVO.ts | 23 ++++ web/src/app/home/knowledge/page.tsx | 22 ++++ web/src/app/infra/entities/api/index.ts | 17 +++ web/src/app/infra/http/HttpClient.ts | 16 +++ 6 files changed, 221 insertions(+) create mode 100644 web/src/app/home/knowledge/components/kb-card/KBCard.module.css create mode 100644 web/src/app/home/knowledge/components/kb-card/KBCard.tsx create mode 100644 web/src/app/home/knowledge/components/kb-card/KBCardVO.ts diff --git a/web/src/app/home/knowledge/components/kb-card/KBCard.module.css b/web/src/app/home/knowledge/components/kb-card/KBCard.module.css new file mode 100644 index 00000000..2ecbd44a --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-card/KBCard.module.css @@ -0,0 +1,107 @@ +.cardContainer { + width: 100%; + height: 10rem; + background-color: #fff; + border-radius: 10px; + box-shadow: 0px 2px 2px 0 rgba(0, 0, 0, 0.2); + padding: 1.2rem; + cursor: pointer; + display: flex; + flex-direction: row; + justify-content: space-between; + gap: 0.5rem; +} + +.cardContainer:hover { + box-shadow: 0px 2px 8px 0 rgba(0, 0, 0, 0.1); +} + +.basicInfoContainer { + width: 100%; + height: 100%; + display: flex; + flex-direction: column; + justify-content: space-between; + gap: 0.4rem; + min-width: 0; +} + +.basicInfoNameContainer { + display: flex; + flex-direction: column; + gap: 0.2rem; +} + +.basicInfoNameText { + font-size: 1.4rem; + font-weight: 500; +} + +.basicInfoDescriptionText { + font-size: 0.9rem; + font-weight: 400; + display: -webkit-box; + -webkit-line-clamp: 3; + -webkit-box-orient: vertical; + overflow: hidden; + text-overflow: ellipsis; + color: #b1b1b1; +} + +.basicInfoLastUpdatedTimeContainer { + display: flex; + flex-direction: row; + align-items: center; + gap: 0.5rem; +} + +.basicInfoUpdateTimeIcon { + width: 1.2rem; + height: 1.2rem; +} + +.basicInfoUpdateTimeText { + font-size: 1rem; + font-weight: 400; +} + +.operationContainer { + display: flex; + flex-direction: column; + align-items: flex-end; + justify-content: space-between; + gap: 0.5rem; + width: 8rem; +} + +.operationDefaultBadge { + display: flex; + flex-direction: row; + gap: 0.5rem; +} + +.operationDefaultBadgeIcon { + width: 1.2rem; + height: 1.2rem; + color: #ffcd27; +} + +.operationDefaultBadgeText { + font-size: 1rem; + font-weight: 400; + color: #ffcd27; +} + +.bigText { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + font-size: 1.4rem; + font-weight: bold; + max-width: 100%; +} + +.debugButtonIcon { + width: 1.2rem; + height: 1.2rem; +} diff --git a/web/src/app/home/knowledge/components/kb-card/KBCard.tsx b/web/src/app/home/knowledge/components/kb-card/KBCard.tsx new file mode 100644 index 00000000..5d49e738 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-card/KBCard.tsx @@ -0,0 +1,36 @@ +import { KnowledgeBaseVO } from '@/app/home/knowledge/components/kb-card/KBCardVO'; +import { useTranslation } from 'react-i18next'; +import styles from './KBCard.module.css'; + +export default function KBCard({ kbCardVO }: { kbCardVO: KnowledgeBaseVO }) { + const { t } = useTranslation(); + return ( +
+
+
+
+ {kbCardVO.name} +
+
+ {kbCardVO.description} +
+
+ +
+ + + +
+ {t('knowledge.bases.updateTime')} + {kbCardVO.lastUpdatedTimeAgo} +
+
+
+
+ ); +} diff --git a/web/src/app/home/knowledge/components/kb-card/KBCardVO.ts b/web/src/app/home/knowledge/components/kb-card/KBCardVO.ts new file mode 100644 index 00000000..bfbc2adb --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-card/KBCardVO.ts @@ -0,0 +1,23 @@ +export interface IKnowledgeBaseVO { + id: string; + name: string; + description: string; + embeddingModelUUID: string; + lastUpdatedTimeAgo: string; +} + +export class KnowledgeBaseVO implements IKnowledgeBaseVO { + id: string; + name: string; + description: string; + embeddingModelUUID: string; + lastUpdatedTimeAgo: string; + + constructor(props: IKnowledgeBaseVO) { + this.id = props.id; + this.name = props.name; + this.description = props.description; + this.embeddingModelUUID = props.embeddingModelUUID; + this.lastUpdatedTimeAgo = props.lastUpdatedTimeAgo; + } +} diff --git a/web/src/app/home/knowledge/page.tsx b/web/src/app/home/knowledge/page.tsx index 7c9fd048..7ee25eac 100644 --- a/web/src/app/home/knowledge/page.tsx +++ b/web/src/app/home/knowledge/page.tsx @@ -2,8 +2,22 @@ import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; import styles from './knowledgeBase.module.css'; +import { useTranslation } from 'react-i18next'; +import { useState } from 'react'; +import { KnowledgeBaseVO } from '@/app/home/knowledge/components/kb-card/KBCardVO'; +import KBCard from '@/app/home/knowledge/components/kb-card/KBCard'; export default function KnowledgePage() { + const { t } = useTranslation(); + const [knowledgeBaseList, setKnowledgeBaseList] = useState( + [], + ); + + const handleKBCardClick = (kbId: string) => { + // setIsEditForm(false); + // setModalOpen(true); + }; + return (
@@ -16,6 +30,14 @@ export default function KnowledgePage() { // setModalOpen(true); }} /> + + {knowledgeBaseList.map((kb) => { + return ( +
handleKBCardClick(kb.id)}> + +
+ ); + })}
); diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index 53ddf1dd..a44b1991 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -133,6 +133,23 @@ export interface Bot { updated_at?: string; } +export interface ApiRespKnowledgeBases { + bases: KnowledgeBase[]; +} + +export interface ApiRespKnowledgeBase { + base: KnowledgeBase; +} + +export interface KnowledgeBase { + uuid?: string; + name: string; + description: string; + embedding_model_uuid: string; + created_at?: string; + updated_at?: string; +} + // plugins export interface ApiRespPlugins { plugins: Plugin[]; diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 1fd335d9..5c6e0abd 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -34,6 +34,9 @@ import { AsyncTask, ApiRespWebChatMessage, ApiRespWebChatMessages, + ApiRespKnowledgeBases, + ApiRespKnowledgeBase, + KnowledgeBase, } from '@/app/infra/entities/api'; import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest'; import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse'; @@ -427,6 +430,19 @@ class HttpClient { return this.post(`/api/v1/platform/bots/${botId}/logs`, request); } + // ============ Knowledge Base API ============ + public getKnowledgeBases(): Promise { + return this.get('/api/v1/knowledge/bases'); + } + + public getKnowledgeBase(uuid: string): Promise { + return this.get(`/api/v1/knowledge/bases/${uuid}`); + } + + public createKnowledgeBase(base: KnowledgeBase): Promise<{ uuid: string }> { + return this.post('/api/v1/knowledge/bases', base); + } + // ============ Plugins API ============ public getPlugins(): Promise { return this.get('/api/v1/plugins'); From 0e5c9e19e16f30cbbfb0ab922d431328e95f9cb1 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 5 Jul 2025 21:03:14 +0800 Subject: [PATCH 14/60] feat: complete support_type for 302ai and compshare requester --- pkg/provider/modelmgr/requesters/302aichatcmpl.yaml | 2 ++ pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml b/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml index 2d9df778..754a9078 100644 --- a/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./302aichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml b/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml index ca57c31c..2b7f9a70 100644 --- a/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./compsharechatcmpl.py From 39c062f73e83181d8ec57d42600cf00caf86f9a5 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 5 Jul 2025 21:56:17 +0800 Subject: [PATCH 15/60] perf: format --- .../http/controller/groups/knowledge_base.py | 65 +++++++++---------- .../controller/groups/pipelines/pipelines.py | 2 +- 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py index f9aa09e0..e9606a3d 100644 --- a/pkg/api/http/controller/groups/knowledge_base.py +++ b/pkg/api/http/controller/groups/knowledge_base.py @@ -1,49 +1,36 @@ import quart from .. import group + @group.group_class('knowledge_base', '/api/v1/knowledge/bases') class KnowledgeBaseRouterGroup(group.RouterGroup): - # 定义成功方法 def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response: - return quart.jsonify({ - "code": code, - "data": data or {}, - "msg": msg - }) - + return quart.jsonify({'code': code, 'data': data or {}, 'msg': msg}) - async def initialize(self) -> None: - - @self.route('', methods=['POST', 'GET']) async def _() -> str: - if quart.request.method == 'GET': knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases() bases_list = [ { - "uuid": kb.id, - "name": kb.name, - "description": kb.description, - } for kb in knowledge_bases + 'uuid': kb.id, + 'name': kb.name, + 'description': kb.description, + } + for kb in knowledge_bases ] - return self.success(code=0, - data={'bases': bases_list}, - msg='ok') + return self.success(code=0, data={'bases': bases_list}, msg='ok') json_data = await quart.request.json knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( - json_data.get('name'), - json_data.get('description') + json_data.get('name'), json_data.get('description') ) - return self.success(code=0, - data={}, - msg='ok') + _ = knowledge_base_uuid + return self.success(code=0, data={}, msg='ok') - - @self.route('/', methods=['GET','DELETE']) + @self.route('/', methods=['GET', 'DELETE']) async def _(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid) @@ -54,11 +41,11 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): return self.success( code=0, data={ - "name": knowledge_base.name, - "description": knowledge_base.description, - "uuid": knowledge_base.id + 'name': knowledge_base.name, + 'description': knowledge_base.description, + 'uuid': knowledge_base.id, }, - msg='ok' + msg='ok', ) elif quart.request.method == 'DELETE': await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid) @@ -68,15 +55,21 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): async def _(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) - return self.success(code=0,data=[{ - "id": file.id, - "file_name": file.file_name, - "status": file.status - } for file in files],msg='ok') - + return self.success( + code=0, + data=[ + { + 'id': file.id, + 'file_name': file.file_name, + 'status': file.status, + } + for file in files + ], + msg='ok', + ) + # delete specific file in knowledge base @self.route('//files/', methods=['DELETE']) async def _(knowledge_base_uuid: str, file_id: str) -> str: await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) return self.success(code=0, msg='ok') - diff --git a/pkg/api/http/controller/groups/pipelines/pipelines.py b/pkg/api/http/controller/groups/pipelines/pipelines.py index 1a8036cc..96ca239a 100644 --- a/pkg/api/http/controller/groups/pipelines/pipelines.py +++ b/pkg/api/http/controller/groups/pipelines/pipelines.py @@ -2,7 +2,7 @@ from __future__ import annotations import quart -from .. import group +from ... import group @group.group_class('pipelines', '/api/v1/pipelines') From 8d28ace25276820714d33c5aedf359b48d0faf3e Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 5 Jul 2025 21:56:54 +0800 Subject: [PATCH 16/60] perf: ruff check --fix --- libs/wechatpad_api/__init__.py | 2 +- libs/wechatpad_api/api/chatroom.py | 6 +- libs/wechatpad_api/api/downloadpai.py | 25 +- libs/wechatpad_api/api/friend.py | 5 - libs/wechatpad_api/api/login.py | 60 +-- libs/wechatpad_api/api/message.py | 111 ++--- libs/wechatpad_api/util/http_util.py | 48 +- pkg/entity/persistence/vector.py | 11 +- pkg/platform/sources/aiocqhttp.py | 229 ++++++---- pkg/platform/sources/discord.py | 20 +- pkg/platform/sources/lark.py | 8 +- pkg/platform/sources/nakuru.py | 5 +- pkg/platform/sources/officialaccount.py | 4 +- pkg/platform/sources/qqofficial.py | 9 +- pkg/platform/sources/slack.py | 8 +- pkg/platform/sources/telegram.py | 4 +- pkg/platform/sources/wechatpad.py | 425 +++++++----------- pkg/platform/sources/wecom.py | 6 +- pkg/platform/sources/wecomcs.py | 6 +- pkg/rag/knowledge/services/database.py | 32 +- .../knowledge/services/embedding_models.py | 165 +++---- pkg/rag/knowledge/services/parser.py | 128 +++--- pkg/rag/knowledge/services/retriever.py | 67 +-- 23 files changed, 647 insertions(+), 737 deletions(-) diff --git a/libs/wechatpad_api/__init__.py b/libs/wechatpad_api/__init__.py index 23c23fb2..9ac533f7 100644 --- a/libs/wechatpad_api/__init__.py +++ b/libs/wechatpad_api/__init__.py @@ -1 +1 @@ -from .client import WeChatPadClient \ No newline at end of file +from .client import WeChatPadClient as WeChatPadClient diff --git a/libs/wechatpad_api/api/chatroom.py b/libs/wechatpad_api/api/chatroom.py index a7af207c..2d9281a2 100644 --- a/libs/wechatpad_api/api/chatroom.py +++ b/libs/wechatpad_api/api/chatroom.py @@ -1,4 +1,4 @@ -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json class ChatRoomApi: @@ -7,8 +7,6 @@ class ChatRoomApi: self.token = token def get_chatroom_member_detail(self, chatroom_name): - params = { - "ChatRoomName": chatroom_name - } + params = {'ChatRoomName': chatroom_name} url = self.base_url + '/group/GetChatroomMemberDetail' return post_json(url, token=self.token, data=params) diff --git a/libs/wechatpad_api/api/downloadpai.py b/libs/wechatpad_api/api/downloadpai.py index a82a5674..2d45fac6 100644 --- a/libs/wechatpad_api/api/downloadpai.py +++ b/libs/wechatpad_api/api/downloadpai.py @@ -1,32 +1,23 @@ -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json import httpx import base64 + class DownloadApi: def __init__(self, base_url, token): self.base_url = base_url self.token = token def send_download(self, aeskey, file_type, file_url): - json_data = { - "AesKey": aeskey, - "FileType": file_type, - "FileURL": file_url - } - url = self.base_url + "/message/SendCdnDownload" + json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url} + url = self.base_url + '/message/SendCdnDownload' return post_json(url, token=self.token, data=json_data) - def get_msg_voice(self,buf_id, length, new_msgid): - json_data = { - "Bufid": buf_id, - "Length": length, - "NewMsgId": new_msgid, - "ToUserName": "" - } - url = self.base_url + "/message/GetMsgVoice" + def get_msg_voice(self, buf_id, length, new_msgid): + json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''} + url = self.base_url + '/message/GetMsgVoice' return post_json(url, token=self.token, data=json_data) - async def download_url_to_base64(self, download_url): async with httpx.AsyncClient() as client: response = await client.get(download_url) @@ -36,4 +27,4 @@ class DownloadApi: base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式 return base64_str else: - raise Exception('获取文件失败') \ No newline at end of file + raise Exception('获取文件失败') diff --git a/libs/wechatpad_api/api/friend.py b/libs/wechatpad_api/api/friend.py index 00701a5d..a7a448aa 100644 --- a/libs/wechatpad_api/api/friend.py +++ b/libs/wechatpad_api/api/friend.py @@ -1,11 +1,6 @@ -from libs.wechatpad_api.util.http_util import post_json,async_request -from typing import List, Dict, Any, Optional - - class FriendApi: """联系人API类,处理所有与联系人相关的操作""" def __init__(self, base_url: str, token: str): self.base_url = base_url self.token = token - diff --git a/libs/wechatpad_api/api/login.py b/libs/wechatpad_api/api/login.py index 142a3c85..4aa4ae8d 100644 --- a/libs/wechatpad_api/api/login.py +++ b/libs/wechatpad_api/api/login.py @@ -1,37 +1,34 @@ -from libs.wechatpad_api.util.http_util import async_request,post_json,get_json +from libs.wechatpad_api.util.http_util import post_json, get_json class LoginApi: def __init__(self, base_url: str, token: str = None, admin_key: str = None): - ''' + """ Args: base_url: 原始路径 token: token admin_key: 管理员key - ''' + """ self.base_url = base_url self.token = token # self.admin_key = admin_key - def get_token(self, admin_key, day: int=365): + def get_token(self, admin_key, day: int = 365): # 获取普通token - url = f"{self.base_url}/admin/GenAuthKey1" - json_data = { - "Count": 1, - "Days": day - } + url = f'{self.base_url}/admin/GenAuthKey1' + json_data = {'Count': 1, 'Days': day} return post_json(base_url=url, token=admin_key, data=json_data) - def get_login_qr(self, Proxy: str = ""): - ''' + def get_login_qr(self, Proxy: str = ''): + """ Args: Proxy:异地使用时代理 Returns:json数据 - ''' + """ """ { @@ -49,54 +46,37 @@ class LoginApi: } """ - #获取登录二维码 - url = f"{self.base_url}/login/GetLoginQrCodeNew" + # 获取登录二维码 + url = f'{self.base_url}/login/GetLoginQrCodeNew' check = False - if Proxy != "": + if Proxy != '': check = True - json_data = { - "Check": check, - "Proxy": Proxy - } + json_data = {'Check': check, 'Proxy': Proxy} return post_json(base_url=url, token=self.token, data=json_data) - def get_login_status(self): # 获取登录状态 url = f'{self.base_url}/login/GetLoginStatus' return get_json(base_url=url, token=self.token) - - def logout(self): # 退出登录 url = f'{self.base_url}/login/LogOut' return post_json(base_url=url, token=self.token) - - - - def wake_up_login(self, Proxy: str = ""): + def wake_up_login(self, Proxy: str = ''): # 唤醒登录 url = f'{self.base_url}/login/WakeUpLogin' check = False - if Proxy != "": + if Proxy != '': check = True - json_data = { - "Check": check, - "Proxy": "" - } + json_data = {'Check': check, 'Proxy': ''} return post_json(base_url=url, token=self.token, data=json_data) - - - def login(self,admin_key): + def login(self, admin_key): login_status = self.get_login_status() - if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信": - print("token已经失效,重新获取") + if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信': + print('token已经失效,重新获取') token_data = self.get_token(admin_key) - self.token = token_data["Data"][0] - - - + self.token = token_data['Data'][0] diff --git a/libs/wechatpad_api/api/message.py b/libs/wechatpad_api/api/message.py index 2089ce96..cca76313 100644 --- a/libs/wechatpad_api/api/message.py +++ b/libs/wechatpad_api/api/message.py @@ -1,5 +1,4 @@ - -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json class MessageApi: @@ -7,8 +6,8 @@ class MessageApi: self.base_url = base_url self.token = token - def post_text(self, to_wxid, content, ats: list= []): - ''' + def post_text(self, to_wxid, content, ats: list = []): + """ Args: app_id: 微信id @@ -18,106 +17,64 @@ class MessageApi: Returns: - ''' - url = self.base_url + "/message/SendTextMessage" + """ + url = self.base_url + '/message/SendTextMessage' """发送文字消息""" json_data = { - "MsgItem": [ - { - "AtWxIDList": ats, - "ImageContent": "", - "MsgType": 0, - "TextContent": content, - "ToUserName": to_wxid - } - ] - } - return post_json(base_url=url, token=self.token, data=json_data) + 'MsgItem': [ + {'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid} + ] + } + return post_json(base_url=url, token=self.token, data=json_data) - - - - def post_image(self, to_wxid, img_url, ats: list= []): + def post_image(self, to_wxid, img_url, ats: list = []): """发送图片消息""" # 这里好像可以尝试发送多个暂时未测试 json_data = { - "MsgItem": [ - { - "AtWxIDList": ats, - "ImageContent": img_url, - "MsgType": 0, - "TextContent": '', - "ToUserName": to_wxid - } + 'MsgItem': [ + {'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid} ] } - url = self.base_url + "/message/SendImageMessage" + url = self.base_url + '/message/SendImageMessage' return post_json(base_url=url, token=self.token, data=json_data) def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration): """发送语音消息""" json_data = { - "ToUserName": to_wxid, - "VoiceData": voice_data, - "VoiceFormat": voice_forma, - "VoiceSecond": voice_duration + 'ToUserName': to_wxid, + 'VoiceData': voice_data, + 'VoiceFormat': voice_forma, + 'VoiceSecond': voice_duration, } - url = self.base_url + "/message/SendVoice" + url = self.base_url + '/message/SendVoice' return post_json(base_url=url, token=self.token, data=json_data) - - - - def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag): """发送名片消息""" param = { - "CardAlias": alias, - "CardFlag": flag, - "CardNickName": nick_name, - "CardWxId": name_card_wxid, - "ToUserName": to_wxid + 'CardAlias': alias, + 'CardFlag': flag, + 'CardNickName': nick_name, + 'CardWxId': name_card_wxid, + 'ToUserName': to_wxid, } - url = f"{self.base_url}/message/ShareCardMessage" + url = f'{self.base_url}/message/ShareCardMessage' return post_json(base_url=url, token=self.token, data=param) - def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0): + def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0): """发送emoji消息""" - json_data = { - "EmojiList": [ - { - "EmojiMd5": emoji_md5, - "EmojiSize": emoji_size, - "ToUserName": to_wxid - } - ] - } - url = f"{self.base_url}/message/SendEmojiMessage" + json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]} + url = f'{self.base_url}/message/SendEmojiMessage' return post_json(base_url=url, token=self.token, data=json_data) - def post_app_msg(self, to_wxid,xml_data, contenttype:int=0): + def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0): """发送appmsg消息""" - json_data = { - "AppList": [ - { - "ContentType": contenttype, - "ContentXML": xml_data, - "ToUserName": to_wxid - } - ] - } - url = f"{self.base_url}/message/SendAppMessage" + json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]} + url = f'{self.base_url}/message/SendAppMessage' return post_json(base_url=url, token=self.token, data=json_data) - - def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time): """撤回消息""" - param = { - "ClientMsgId": msg_id, - "CreateTime": create_time, - "NewMsgId": new_msg_id, - "ToUserName": to_wxid - } - url = f"{self.base_url}/message/RevokeMsg" - return post_json(base_url=url, token=self.token, data=param) \ No newline at end of file + param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid} + url = f'{self.base_url}/message/RevokeMsg' + return post_json(base_url=url, token=self.token, data=param) diff --git a/libs/wechatpad_api/util/http_util.py b/libs/wechatpad_api/util/http_util.py index 754003e9..447c29df 100644 --- a/libs/wechatpad_api/util/http_util.py +++ b/libs/wechatpad_api/util/http_util.py @@ -1,10 +1,9 @@ import requests +import aiohttp + def post_json(base_url, token, data=None): - headers = { - 'Content-Type': 'application/json' - } - + headers = {'Content-Type': 'application/json'} url = base_url + f'?key={token}' @@ -18,14 +17,12 @@ def post_json(base_url, token, data=None): else: raise RuntimeError(response.text) except Exception as e: - print(f"http请求失败, url={url}, exception={e}") + print(f'http请求失败, url={url}, exception={e}') raise RuntimeError(str(e)) -def get_json(base_url, token): - headers = { - 'Content-Type': 'application/json' - } +def get_json(base_url, token): + headers = {'Content-Type': 'application/json'} url = base_url + f'?key={token}' @@ -39,21 +36,18 @@ def get_json(base_url, token): else: raise RuntimeError(response.text) except Exception as e: - print(f"http请求失败, url={url}, exception={e}") + print(f'http请求失败, url={url}, exception={e}') raise RuntimeError(str(e)) -import aiohttp -import asyncio - async def async_request( - base_url: str, - token_key: str, - method: str = 'POST', - params: dict = None, - # headers: dict = None, - data: dict = None, - json: dict = None + base_url: str, + token_key: str, + method: str = 'POST', + params: dict = None, + # headers: dict = None, + data: dict = None, + json: dict = None, ): """ 通用异步请求函数 @@ -67,18 +61,11 @@ async def async_request( :param json: JSON数据 :return: 响应文本 """ - headers = { - 'Content-Type': 'application/json' - } - url = f"{base_url}?key={token_key}" + headers = {'Content-Type': 'application/json'} + url = f'{base_url}?key={token_key}' async with aiohttp.ClientSession() as session: async with session.request( - method=method, - url=url, - params=params, - headers=headers, - data=data, - json=json + method=method, url=url, params=params, headers=headers, data=data, json=json ) as response: response.raise_for_status() # 如果状态码不是200,抛出异常 result = await response.json() @@ -89,4 +76,3 @@ async def async_request( # return await result # else: # raise RuntimeError("请求失败",response.text) - diff --git a/pkg/entity/persistence/vector.py b/pkg/entity/persistence/vector.py index 84d1dfb1..465125f5 100644 --- a/pkg/entity/persistence/vector.py +++ b/pkg/entity/persistence/vector.py @@ -1,14 +1,13 @@ -from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary -from sqlalchemy.orm import declarative_base, sessionmaker, relationship -from datetime import datetime -import numpy as np # 用于处理从LargeBinary转换回来的embedding +from sqlalchemy import Column, Integer, ForeignKey, LargeBinary +from sqlalchemy.orm import declarative_base, relationship Base = declarative_base() + class Vector(Base): __tablename__ = 'vectors' id = Column(Integer, primary_key=True, index=True) chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) - embedding = Column(LargeBinary) # Store embeddings as binary + embedding = Column(LargeBinary) # Store embeddings as binary - chunk = relationship("Chunk", back_populates="vector") \ No newline at end of file + chunk = relationship('Chunk', back_populates='vector') diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 3f3ef512..2730874f 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -16,7 +16,6 @@ from ..logger import EventLogger class AiocqhttpMessageConverter(adapter.MessageConverter): - @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, @@ -62,87 +61,170 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): for node in msg.node_list: msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0]) elif isinstance(msg, platform_message.File): - msg_list.append({"type":"file", "data":{'file': msg.url, "name": msg.name}}) + msg_list.append({'type': 'file', 'data': {'file': msg.url, 'name': msg.name}}) elif isinstance(msg, platform_message.Face): - if msg.face_type=='face': + if msg.face_type == 'face': msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id)) - elif msg.face_type=='rps': + elif msg.face_type == 'rps': msg_list.append(aiocqhttp.MessageSegment.rps()) - elif msg.face_type=='dice': + elif msg.face_type == 'dice': msg_list.append(aiocqhttp.MessageSegment.dice()) - else: msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) return msg_list, msg_id, msg_time @staticmethod - async def target2yiri(message: str, message_id: int = -1,bot=None): + async def target2yiri(message: str, message_id: int = -1, bot=None): print(message) message = aiocqhttp.Message(message) def get_face_name(face_id): face_code_dict = { - "2": '好色', - "4": "得意", "5": "流泪", "8": "睡", "9": "大哭", "10": "尴尬", "12": "调皮", "14": "微笑", "16": "酷", - "21": "可爱", - "23": "傲慢", "24": "饥饿", "25": "困", "26": "惊恐", "27": "流汗", "28": "憨笑", "29": "悠闲", - "30": "奋斗", - "32": "疑问", "33": "嘘", "34": "晕", "38": "敲打", "39": "再见", "41": "发抖", "42": "爱情", - "43": "跳跳", - "49": "拥抱", "53": "蛋糕", "60": "咖啡", "63": "玫瑰", "66": "爱心", "74": "太阳", "75": "月亮", - "76": "赞", - "78": "握手", "79": "胜利", "85": "飞吻", "89": "西瓜", "96": "冷汗", "97": "擦汗", "98": "抠鼻", - "99": "鼓掌", - "100": "糗大了", "101": "坏笑", "102": "左哼哼", "103": "右哼哼", "104": "哈欠", "106": "委屈", - "109": "左亲亲", - "111": "可怜", "116": "示爱", "118": "抱拳", "120": "拳头", "122": "爱你", "123": "NO", "124": "OK", - "125": "转圈", - "129": "挥手", "144": "喝彩", "147": "棒棒糖", "171": "茶", "173": "泪奔", "174": "无奈", "175": "卖萌", - "176": "小纠结", "179": "doge", "180": "惊喜", "181": "骚扰", "182": "笑哭", "183": "我最美", - "201": "点赞", - "203": "托脸", "212": "托腮", "214": "啵啵", "219": "蹭一蹭", "222": "抱抱", "227": "拍手", - "232": "佛系", - "240": "喷脸", "243": "甩头", "246": "加油抱抱", "262": "脑阔疼", "264": "捂脸", "265": "辣眼睛", - "266": "哦哟", - "267": "头秃", "268": "问号脸", "269": "暗中观察", "270": "emm", "271": "吃瓜", "272": "呵呵哒", - "273": "我酸了", - "277": "汪汪", "278": "汗", "281": "无眼笑", "282": "敬礼", "284": "面无表情", "285": "摸鱼", - "287": "哦", - "289": "睁眼", "290": "敲开心", "293": "摸锦鲤", "294": "期待", "297": "拜谢", "298": "元宝", - "299": "牛啊", - "305": "右亲亲", "306": "牛气冲天", "307": "喵喵", "314": "仔细分析", "315": "加油", "318": "崇拜", - "319": "比心", - "320": "庆祝", "322": "拒绝", "324": "吃糖", "326": "生气" + '2': '好色', + '4': '得意', + '5': '流泪', + '8': '睡', + '9': '大哭', + '10': '尴尬', + '12': '调皮', + '14': '微笑', + '16': '酷', + '21': '可爱', + '23': '傲慢', + '24': '饥饿', + '25': '困', + '26': '惊恐', + '27': '流汗', + '28': '憨笑', + '29': '悠闲', + '30': '奋斗', + '32': '疑问', + '33': '嘘', + '34': '晕', + '38': '敲打', + '39': '再见', + '41': '发抖', + '42': '爱情', + '43': '跳跳', + '49': '拥抱', + '53': '蛋糕', + '60': '咖啡', + '63': '玫瑰', + '66': '爱心', + '74': '太阳', + '75': '月亮', + '76': '赞', + '78': '握手', + '79': '胜利', + '85': '飞吻', + '89': '西瓜', + '96': '冷汗', + '97': '擦汗', + '98': '抠鼻', + '99': '鼓掌', + '100': '糗大了', + '101': '坏笑', + '102': '左哼哼', + '103': '右哼哼', + '104': '哈欠', + '106': '委屈', + '109': '左亲亲', + '111': '可怜', + '116': '示爱', + '118': '抱拳', + '120': '拳头', + '122': '爱你', + '123': 'NO', + '124': 'OK', + '125': '转圈', + '129': '挥手', + '144': '喝彩', + '147': '棒棒糖', + '171': '茶', + '173': '泪奔', + '174': '无奈', + '175': '卖萌', + '176': '小纠结', + '179': 'doge', + '180': '惊喜', + '181': '骚扰', + '182': '笑哭', + '183': '我最美', + '201': '点赞', + '203': '托脸', + '212': '托腮', + '214': '啵啵', + '219': '蹭一蹭', + '222': '抱抱', + '227': '拍手', + '232': '佛系', + '240': '喷脸', + '243': '甩头', + '246': '加油抱抱', + '262': '脑阔疼', + '264': '捂脸', + '265': '辣眼睛', + '266': '哦哟', + '267': '头秃', + '268': '问号脸', + '269': '暗中观察', + '270': 'emm', + '271': '吃瓜', + '272': '呵呵哒', + '273': '我酸了', + '277': '汪汪', + '278': '汗', + '281': '无眼笑', + '282': '敬礼', + '284': '面无表情', + '285': '摸鱼', + '287': '哦', + '289': '睁眼', + '290': '敲开心', + '293': '摸锦鲤', + '294': '期待', + '297': '拜谢', + '298': '元宝', + '299': '牛啊', + '305': '右亲亲', + '306': '牛气冲天', + '307': '喵喵', + '314': '仔细分析', + '315': '加油', + '318': '崇拜', + '319': '比心', + '320': '庆祝', + '322': '拒绝', + '324': '吃糖', + '326': '生气', } - return face_code_dict.get(face_id,'') + return face_code_dict.get(face_id, '') async def process_message_data(msg_data, reply_list): - if msg_data["type"] == "image": - image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url']) - reply_list.append( - platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) + if msg_data['type'] == 'image': + image_base64, image_format = await image.qq_image_url_to_base64(msg_data['data']['url']) + reply_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) - elif msg_data["type"] == "text": - reply_list.append(platform_message.Plain(text=msg_data["data"]["text"])) + elif msg_data['type'] == 'text': + reply_list.append(platform_message.Plain(text=msg_data['data']['text'])) - elif msg_data["type"] == "forward": # 这里来应该传入转发消息组,暂时传入qoute - for forward_msg_datas in msg_data["data"]["content"]: - for forward_msg_data in forward_msg_datas["message"]: + elif msg_data['type'] == 'forward': # 这里来应该传入转发消息组,暂时传入qoute + for forward_msg_datas in msg_data['data']['content']: + for forward_msg_data in forward_msg_datas['message']: await process_message_data(forward_msg_data, reply_list) - elif msg_data["type"] == "at": - if msg_data["data"]['qq'] == 'all': + elif msg_data['type'] == 'at': + if msg_data['data']['qq'] == 'all': reply_list.append(platform_message.AtAll()) else: reply_list.append( platform_message.At( - target=msg_data["data"]['qq'], + target=msg_data['data']['qq'], ) ) - yiri_msg_list = [] yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) @@ -161,10 +243,10 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): elif msg.type == 'text': yiri_msg_list.append(platform_message.Plain(text=msg.data['text'])) elif msg.type == 'image': - emoji_id = msg.data.get("emoji_package_id", None) + emoji_id = msg.data.get('emoji_package_id', None) if emoji_id: face_id = emoji_id - face_name = msg.data.get("summary", '') + face_name = msg.data.get('summary', '') image_msg = platform_message.Face(face_id=face_id, face_name=face_name) else: image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url']) @@ -178,14 +260,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): # await process_message_data(msg_data, yiri_msg_list) pass - elif msg.type == 'reply': # 此处处理引用消息传入Qoute - msg_datas = await bot.get_msg(message_id=msg.data["id"]) + msg_datas = await bot.get_msg(message_id=msg.data['id']) - for msg_data in msg_datas["message"]: + for msg_data in msg_datas['message']: await process_message_data(msg_data, reply_list) - reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list) + reply_msg = platform_message.Quote( + message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list + ) yiri_msg_list.append(reply_msg) elif msg.type == 'file': @@ -194,49 +277,36 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): file_data = await bot.get_file(file_id=file_id) file_name = file_data.get('file_name') file_path = file_data.get('file') + _ = file_path file_url = file_data.get('file_url') file_size = file_data.get('file_size') - yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size)) + yiri_msg_list.append(platform_message.File(id=file_id, name=file_name, url=file_url, size=file_size)) elif msg.type == 'face': face_id = msg.data['id'] face_name = msg.data['raw']['faceText'] if not face_name: face_name = get_face_name(face_id) - yiri_msg_list.append(platform_message.Face(face_id=int(face_id),face_name=face_name.replace('/',''))) + yiri_msg_list.append(platform_message.Face(face_id=int(face_id), face_name=face_name.replace('/', ''))) elif msg.type == 'rps': face_id = msg.data['result'] - yiri_msg_list.append(platform_message.Face(face_type="rps",face_id=int(face_id),face_name='猜拳')) + yiri_msg_list.append(platform_message.Face(face_type='rps', face_id=int(face_id), face_name='猜拳')) elif msg.type == 'dice': face_id = msg.data['result'] - yiri_msg_list.append(platform_message.Face(face_type='dice',face_id=int(face_id),face_name='骰子')) - - - - - - - - + yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子')) chain = platform_message.MessageChain(yiri_msg_list) return chain - - - - class AiocqhttpEventConverter(adapter.EventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): return event.source_platform_object @staticmethod - async def target2yiri(event: aiocqhttp.Event,bot=None): - yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id,bot) - - + async def target2yiri(event: aiocqhttp.Event, bot=None): + yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot) if event.message_type == 'group': permission = 'MEMBER' @@ -316,7 +386,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] if target_type == 'group': - await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg) elif target_type == 'person': await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg) @@ -345,7 +414,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id try: - return await callback(await self.event_converter.target2yiri(event,self.bot), self) + return await callback(await self.event_converter.target2yiri(event, self.bot), self) except Exception: await self.logger.error(f'Error in on_message: {traceback.format_exc()}') traceback.print_exc() diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index 4f5cac28..6cc09a72 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -8,7 +8,6 @@ import base64 import uuid import os import datetime -import io import aiohttp @@ -78,10 +77,10 @@ class DiscordMessageConverter(adapter.MessageConverter): # 确保路径没有空字节 clean_path = ele.path.replace('\x00', '') clean_path = os.path.abspath(clean_path) - + if not os.path.exists(clean_path): continue # 跳过不存在的文件 - + try: with open(clean_path, 'rb') as f: image_bytes = f.read() @@ -101,12 +100,13 @@ class DiscordMessageConverter(adapter.MessageConverter): filename = f'{uuid.uuid4()}.webp' # 默认保持PNG except Exception as e: - print(f"Error reading image file {clean_path}: {e}") + print(f'Error reading image file {clean_path}: {e}') continue # 跳过读取失败的文件 if image_bytes: # 使用BytesIO创建文件对象,避免路径问题 import io + image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename)) elif isinstance(ele, platform_message.Plain): text_string += ele.text @@ -261,25 +261,25 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): msg_to_send, image_files = await self.message_converter.yiri2target(message) - + try: # 获取频道对象 channel = self.bot.get_channel(int(target_id)) if channel is None: # 如果本地缓存中没有,尝试从API获取 channel = await self.bot.fetch_channel(int(target_id)) - + args = { 'content': msg_to_send, } - + if len(image_files) > 0: args['files'] = image_files - + await channel.send(**args) - + except Exception as e: - await self.logger.error(f"Discord send_message failed: {e}") + await self.logger.error(f'Discord send_message failed: {e}') raise e async def reply_message( diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index d1116362..f8faf522 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -378,15 +378,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter): if 'im.message.receive_v1' == type: try: event = await self.event_converter.target2yiri(p2v1, self.api_client) - except Exception as e: - await self.logger.error(f"Error in lark callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) return {'code': 200, 'message': 'ok'} - except Exception as e: - await self.logger.error(f"Error in lark callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') return {'code': 500, 'message': 'error'} async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1): diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 389a2db1..16ad54db 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): content=content_list, ) nakuru_forward_node_list.append(nakuru_forward_node) - except Exception as e: + except Exception: import traceback + traceback.print_exc() nakuru_msg_list.append(nakuru_forward_node_list) @@ -276,7 +277,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): # 注册监听器 self.bot.receiver(source_cls.__name__)(listener_wrapper) except Exception as e: - self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}") + self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}') raise e def unregister_listener( diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 030db56d..3fc1e393 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -125,8 +125,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in officialaccount callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in officialaccount callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index c61afea4..63ab531f 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -154,10 +154,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员') self.bot = QQOfficialClient( - app_id=config['appid'], - secret=config['secret'], - token=config['token'], - logger=self.logger + app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger ) async def reply_message( @@ -224,8 +221,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = 'justbot' try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in qqofficial callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in qqofficial callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message) diff --git a/pkg/platform/sources/slack.py b/pkg/platform/sources/slack.py index 6dfcff59..1bd5aa2d 100644 --- a/pkg/platform/sources/slack.py +++ b/pkg/platform/sources/slack.py @@ -104,7 +104,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter): if missing_keys: raise ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员') - self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger) + self.bot = SlackClient( + bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger + ) async def reply_message( self, @@ -139,8 +141,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = 'SlackBot' try: return await callback(await self.event_converter.target2yiri(event, self.bot), self) - except Exception as e: - await self.logger.error(f"Error in slack callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in slack callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('im')(on_message) diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index 266d994e..c2fcc22e 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -160,8 +160,8 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): try: lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id) await self.listeners[type(lb_event)](lb_event, self) - except Exception as e: - await self.logger.error(f"Error in telegram callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') self.application = ApplicationBuilder().token(self.config['token']).build() self.bot = self.application.bot diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index fdd4a69b..75cad727 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -1,5 +1,4 @@ import requests -import websockets import websocket import json import time @@ -10,53 +9,40 @@ from libs.wechatpad_api.client import WeChatPadClient import typing import asyncio import traceback -import time import re import base64 -import uuid -import json -import os import copy -import datetime import threading import quart -import aiohttp from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ...utils import image from ..logger import EventLogger import xml.etree.ElementTree as ET -from typing import Optional, List, Tuple +from typing import Optional, Tuple from functools import partial import logging -class WeChatPadMessageConverter(adapter.MessageConverter): +class WeChatPadMessageConverter(adapter.MessageConverter): def __init__(self, config: dict): self.config = config - self.bot = WeChatPadClient(self.config["wechatpad_url"],self.config["token"]) - self.logger = logging.getLogger("WeChatPadMessageConverter") + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) + self.logger = logging.getLogger('WeChatPadMessageConverter') @staticmethod - async def yiri2target( - message_chain: platform_message.MessageChain - ) -> list[dict]: + async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]: content_list = [] - current_file_path = os.path.abspath(__file__) - - for component in message_chain: if isinstance(component, platform_message.At): - content_list.append({"type": "at", "target": component.target}) + content_list.append({'type': 'at', 'target': component.target}) elif isinstance(component, platform_message.Plain): - content_list.append({"type": "text", "content": component.text}) + content_list.append({'type': 'text', 'content': component.text}) elif isinstance(component, platform_message.Image): if component.url: async with httpx.AsyncClient() as client: @@ -68,15 +54,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter): else: raise Exception('获取文件失败') # pass - content_list.append({"type": "image", "image": base64_str}) + content_list.append({'type': 'image', 'image': base64_str}) elif component.base64: - content_list.append({"type": "image", "image": component.base64}) + content_list.append({'type': 'image', 'image': component.base64}) elif isinstance(component, platform_message.WeChatEmoji): content_list.append( - {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size}) + {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size} + ) elif isinstance(component, platform_message.Voice): - content_list.append({"type": "voice", "data": component.url, "duration": component.length, "forma": 0}) + content_list.append({'type': 'voice', 'data': component.url, 'duration': component.length, 'forma': 0}) elif isinstance(component, platform_message.WeChatAppMsg): content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg}) elif isinstance(component, platform_message.Forward): @@ -86,28 +73,23 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return content_list - - async def target2yiri( - self, - message: dict, - bot_account_id: str - ) -> platform_message.MessageChain: + async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain: """外部消息转平台消息""" # 数据预处理 message_list = [] ats_bot = False # 是否被@ - content = message["content"]["str"] + content = message['content']['str'] content_no_preifx = content # 群消息则去掉前缀 is_group_message = self._is_group_message(message) if is_group_message: ats_bot = self._ats_bot(message, bot_account_id) - if "@所有人" in content: + if '@所有人' in content: message_list.append(platform_message.AtAll()) elif ats_bot: message_list.append(platform_message.At(target=bot_account_id)) content_no_preifx, _ = self._extract_content_and_sender(content) - msg_type = message["msg_type"] + msg_type = message['msg_type'] # 映射消息类型到处理器方法 handler_map = { @@ -129,11 +111,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_text( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理文本消息 (msg_type=1)""" if message and self._is_group_message(message): pattern = r'@\S{1,20}' @@ -141,16 +119,12 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain([platform_message.Plain(content_no_preifx)]) - async def _handler_image( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理图像消息 (msg_type=3)""" try: image_xml = content_no_preifx if not image_xml: - return platform_message.MessageChain([platform_message.Unknown("[图片内容为空]")]) + return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')]) root = ET.fromstring(image_xml) # 提取img标签的属性 @@ -160,28 +134,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter): cdnthumburl = img_tag.get('cdnthumburl') # cdnmidimgurl = img_tag.get('cdnmidimgurl') - image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl) - if image_data["Data"]['FileData'] == '': + if image_data['Data']['FileData'] == '': image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl) - base64_str = image_data["Data"]['FileData'] + base64_str = image_data['Data']['FileData'] # self.logger.info(f"data:image/png;base64,{base64_str}") - elements = [ - platform_message.Image(base64=f"data:image/png;base64,{base64_str}"), + platform_message.Image(base64=f'data:image/png;base64,{base64_str}'), # platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发 ] return platform_message.MessageChain(elements) except Exception as e: - self.logger.error(f"处理图片失败: {str(e)}") - return platform_message.MessageChain([platform_message.Unknown("[图片处理失败]")]) + self.logger.error(f'处理图片失败: {str(e)}') + return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')]) - async def _handler_voice( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理语音消息 (msg_type=34)""" message_List = [] try: @@ -197,39 +165,33 @@ class WeChatPadMessageConverter(adapter.MessageConverter): bufid = voicemsg.get('bufid') length = voicemsg.get('voicelength') voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id)) - audio_base64 = voice_data["Data"]['Base64'] + audio_base64 = voice_data['Data']['Base64'] # 验证语音数据有效性 if not audio_base64: - message_List.append(platform_message.Unknown(text="[语音内容为空]")) + message_List.append(platform_message.Unknown(text='[语音内容为空]')) return platform_message.MessageChain(message_List) # 转换为平台支持的语音格式(如 Silk 格式) - voice_element = platform_message.Voice( - base64=f"data:audio/silk;base64,{audio_base64}" - ) + voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}') message_List.append(voice_element) except KeyError as e: - self.logger.error(f"语音数据字段缺失: {str(e)}") - message_List.append(platform_message.Unknown(text="[语音数据解析失败]")) + self.logger.error(f'语音数据字段缺失: {str(e)}') + message_List.append(platform_message.Unknown(text='[语音数据解析失败]')) except Exception as e: - self.logger.error(f"处理语音消息异常: {str(e)}") - message_List.append(platform_message.Unknown(text="[语音处理失败]")) + self.logger.error(f'处理语音消息异常: {str(e)}') + message_List.append(platform_message.Unknown(text='[语音处理失败]')) return platform_message.MessageChain(message_List) - async def _handler_compound( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理复合消息 (msg_type=49),根据子类型分派""" try: xml_data = ET.fromstring(content_no_preifx) appmsg_data = xml_data.find('.//appmsg') if appmsg_data: - data_type = appmsg_data.findtext('.//type', "") + data_type = appmsg_data.findtext('.//type', '') # 二次分派处理器 sub_handler_map = { '57': self._handler_compound_quote, @@ -238,9 +200,9 @@ class WeChatPadMessageConverter(adapter.MessageConverter): '74': self._handler_compound_file, '33': self._handler_compound_mini_program, '36': self._handler_compound_mini_program, - '2000': partial(self._handler_compound_unsupported, text="[转账消息]"), - '2001': partial(self._handler_compound_unsupported, text="[红包消息]"), - '51': partial(self._handler_compound_unsupported, text="[视频号消息]"), + '2000': partial(self._handler_compound_unsupported, text='[转账消息]'), + '2001': partial(self._handler_compound_unsupported, text='[红包消息]'), + '51': partial(self._handler_compound_unsupported, text='[视频号消息]'), } handler = sub_handler_map.get(data_type, self._handler_compound_unsupported) @@ -251,56 +213,54 @@ class WeChatPadMessageConverter(adapter.MessageConverter): else: return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) except Exception as e: - self.logger.error(f"解析复合消息失败: {str(e)}") + self.logger.error(f'解析复合消息失败: {str(e)}') return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) async def _handler_compound_quote( - self, - message: Optional[dict], - xml_data: ET.Element + self, message: Optional[dict], xml_data: ET.Element ) -> platform_message.MessageChain: """处理引用消息 (data_type=57)""" message_list = [] -# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) + # self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) appmsg_data = xml_data.find('.//appmsg') - quote_data = "" # 引用原文 + quote_data = '' # 引用原文 quote_id = None # 引用消息的原发送者 tousername = None # 接收方: 所属微信的wxid - user_data = "" # 用户消息 + user_data = '' # 用户消息 sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member # 引用消息转发 if appmsg_data: - user_data = appmsg_data.findtext('.//title') or "" + user_data = appmsg_data.findtext('.//title') or '' quote_data = appmsg_data.find('.//refermsg').findtext('.//content') quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') - message_list.append( - platform_message.WeChatAppMsg( - app_msg=ET.tostring(appmsg_data, encoding='unicode')) - ) + message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode'))) if message: - tousername = message['to_user_name']["str"] - + tousername = message['to_user_name']['str'] + + _ = quote_id + _ = tousername + if quote_data: quote_data_message_list = platform_message.MessageChain() # 文本消息 try: - if "" not in quote_data: + if '' not in quote_data: quote_data_message_list.append(platform_message.Plain(quote_data)) else: # 引用消息展开 quote_data_xml = ET.fromstring(quote_data) - if quote_data_xml.find("img"): + if quote_data_xml.find('img'): quote_data_message_list.extend(await self._handler_image(None, quote_data)) - elif quote_data_xml.find("voicemsg"): + elif quote_data_xml.find('voicemsg'): quote_data_message_list.extend(await self._handler_voice(None, quote_data)) - elif quote_data_xml.find("videomsg"): + elif quote_data_xml.find('videomsg'): quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理 else: # appmsg quote_data_message_list.extend(await self._handler_compound(None, quote_data)) except Exception as e: - self.logger.error(f"处理引用消息异常 expcetion:{e}") + self.logger.error(f'处理引用消息异常 expcetion:{e}') quote_data_message_list.append(platform_message.Plain(quote_data)) message_list.append( platform_message.Quote( @@ -315,15 +275,11 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_compound_file( - self, - message: dict, - xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理文件消息 (data_type=6)""" file_data = xml_data.find('.//appmsg') - if file_data.findtext('.//type', "") == "74": + if file_data.findtext('.//type', '') == '74': return None else: @@ -346,22 +302,21 @@ class WeChatPadMessageConverter(adapter.MessageConverter): file_data = self.bot.cdn_download(aeskey=aeskey, file_type=5, file_url=cdnthumburl) - file_base64 = file_data["Data"]['FileData'] + file_base64 = file_data['Data']['FileData'] # print(file_data) - file_size = file_data["Data"]['TotalSize'] + file_size = file_data['Data']['TotalSize'] # print(file_base64) - return platform_message.MessageChain([ - platform_message.WeChatFile(file_id=file_id, file_name=file_name, file_size=file_size, - file_base64=file_base64), - platform_message.WeChatForwardFile(xml_data=xml_data_str) - ]) + return platform_message.MessageChain( + [ + platform_message.WeChatFile( + file_id=file_id, file_name=file_name, file_size=file_size, file_base64=file_base64 + ), + platform_message.WeChatForwardFile(xml_data=xml_data_str), + ] + ) - async def _handler_compound_link( - self, - message: dict, - xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理链接消息(如公众号文章、外部网页)""" message_list = [] try: @@ -374,56 +329,38 @@ class WeChatPadMessageConverter(adapter.MessageConverter): link_title=appmsg.findtext('title', ''), link_desc=appmsg.findtext('des', ''), link_url=appmsg.findtext('url', ''), - link_thumb_url=appmsg.findtext("thumburl", '') # 这个字段拿不到 + link_thumb_url=appmsg.findtext('thumburl', ''), # 这个字段拿不到 ) ) # 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。 - message_list.append( - platform_message.WeChatAppMsg( - app_msg=ET.tostring(appmsg, encoding='unicode') - ) - ) + message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg, encoding='unicode'))) except Exception as e: - self.logger.error(f"解析链接消息失败: {str(e)}") + self.logger.error(f'解析链接消息失败: {str(e)}') return platform_message.MessageChain(message_list) async def _handler_compound_mini_program( - self, - message: dict, - xml_data: ET.Element + self, message: dict, xml_data: ET.Element ) -> platform_message.MessageChain: """处理小程序消息(如小程序卡片、服务通知)""" xml_data_str = ET.tostring(xml_data, encoding='unicode') - return platform_message.MessageChain([ - platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str) - ]) + return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)]) - async def _handler_default( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理未知消息类型""" if message: - msg_type = message["msg_type"] + msg_type = message['msg_type'] else: - msg_type = "" - return platform_message.MessageChain([ - platform_message.Unknown(text=f"[未知消息类型 msg_type:{msg_type}]") - ]) + msg_type = '' + return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')]) def _handler_compound_unsupported( - self, - message: dict, - xml_data: str, - text: Optional[str] = None + self, message: dict, xml_data: str, text: Optional[str] = None ) -> platform_message.MessageChain: """处理未支持复合消息类型(msg_type=49)子类型""" if not text: - text = f"[xml_data={xml_data}]" + text = f'[xml_data={xml_data}]' content_list = [] - content_list.append( - platform_message.Unknown(text=f"[处理未支持复合消息类型[msg_type=49]|{text}")) + content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}')) return platform_message.MessageChain(content_list) @@ -432,7 +369,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): ats_bot = False try: to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid - raw_content = message["content"]["str"] # 原始消息内容 + raw_content = message['content']['str'] # 原始消息内容 content_no_prefix, _ = self._extract_content_and_sender(raw_content) # 直接艾特机器人(这个有bug,当被引用的消息里面有@bot,会套娃 # ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix) @@ -443,7 +380,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): msg_source = message.get('msg_source', '') or '' if len(msg_source) > 0: msg_source_data = ET.fromstring(msg_source) - at_user_list = msg_source_data.findtext("atuserlist") or "" + at_user_list = msg_source_data.findtext('atuserlist') or '' ats_bot = ats_bot or (to_user_name in at_user_list) # 引用bot if message.get('msg_type', 0) == 49: @@ -454,7 +391,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者 ats_bot = ats_bot or (quote_id == tousername) except Exception as e: - self.logger.error(f"_ats_bot got except: {e}") + self.logger.error(f'_ats_bot got except: {e}') finally: return ats_bot @@ -463,47 +400,41 @@ class WeChatPadMessageConverter(adapter.MessageConverter): try: # 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉 # add: 有些用户的wxid不是上述格式。换成user_name: - regex = re.compile(r"^[a-zA-Z0-9_\-]{5,20}:") - line_split = raw_content.split("\n") + regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:') + line_split = raw_content.split('\n') if len(line_split) > 0 and regex.match(line_split[0]): - raw_content = "\n".join(line_split[1:]) - sender_id = line_split[0].strip(":") + raw_content = '\n'.join(line_split[1:]) + sender_id = line_split[0].strip(':') return raw_content, sender_id except Exception as e: - self.logger.error(f"_extract_content_and_sender got except: {e}") + self.logger.error(f'_extract_content_and_sender got except: {e}') finally: return raw_content, None # 是否是群消息 def _is_group_message(self, message: dict) -> bool: from_user_name = message['from_user_name']['str'] - return from_user_name.endswith("@chatroom") + return from_user_name.endswith('@chatroom') class WeChatPadEventConverter(adapter.EventConverter): - def __init__(self, config: dict): self.config = config self.message_converter = WeChatPadMessageConverter(config) - self.logger = logging.getLogger("WeChatPadEventConverter") - + self.logger = logging.getLogger('WeChatPadEventConverter') + @staticmethod - async def yiri2target( - event: platform_events.MessageEvent - ) -> dict: + async def yiri2target(event: platform_events.MessageEvent) -> dict: pass - async def target2yiri( - self, - event: dict, - bot_account_id: str - ) -> platform_events.MessageEvent: - + async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent: # 排除公众号以及微信团队消息 - if event['from_user_name']['str'].startswith('gh_') \ - or event['from_user_name']['str']=='weixin'\ - or event['from_user_name']['str'] == "newsapp"\ - or event['from_user_name']['str'] == self.config["wxid"]: + if ( + event['from_user_name']['str'].startswith('gh_') + or event['from_user_name']['str'] == 'weixin' + or event['from_user_name']['str'] == 'newsapp' + or event['from_user_name']['str'] == self.config['wxid'] + ): return None message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id) @@ -512,7 +443,7 @@ class WeChatPadEventConverter(adapter.EventConverter): if '@chatroom' in event['from_user_name']['str']: # 找出开头的 wxid_ 字符串,以:结尾 - sender_wxid = event['content']['str'].split(":")[0] + sender_wxid = event['content']['str'].split(':')[0] return platform_events.GroupMessage( sender=platform_entities.GroupMember( @@ -524,13 +455,13 @@ class WeChatPadEventConverter(adapter.EventConverter): name=event['from_user_name']['str'], permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, ), message_chain=message_chain, - time=event["create_time"], + time=event['create_time'], source_platform_object=event, ) else: @@ -541,13 +472,13 @@ class WeChatPadEventConverter(adapter.EventConverter): remark='', ), message_chain=message_chain, - time=event["create_time"], + time=event['create_time'], source_platform_object=event, ) class WeChatPadAdapter(adapter.MessagePlatformAdapter): - name: str = "WeChatPad" # 定义适配器名称 + name: str = 'WeChatPad' # 定义适配器名称 bot: WeChatPadClient quart_app: quart.Quart @@ -580,27 +511,21 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # self.ap.logger.debug(f"Gewechat callback event: {data}") # print(data) - try: event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) - except Exception as e: - await self.logger.error(f"Error in wechatpad callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wechatpad callback: {traceback.format_exc()}') if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) return 'ok' - - async def _handle_message( - self, - message: platform_message.MessageChain, - target_id: str - ): + async def _handle_message(self, message: platform_message.MessageChain, target_id: str): """统一消息处理核心逻辑""" content_list = await self.message_converter.yiri2target(message) # print(content_list) - at_targets = [item["target"] for item in content_list if item["type"] == "at"] + at_targets = [item['target'] for item in content_list if item['type'] == 'at'] # print(at_targets) # 处理@逻辑 at_targets = at_targets or [] @@ -608,7 +533,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if at_targets: member_info = self.bot.get_chatroom_member_detail( target_id, - )["Data"]["member_data"]["chatroom_member_list"] + )['Data']['member_data']['chatroom_member_list'] # 处理消息组件 for msg in content_list: @@ -616,63 +541,51 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if msg['type'] == 'text' and at_targets: at_nick_name_list = [] for member in member_info: - if member["user_name"] in at_targets: + if member['user_name'] in at_targets: at_nick_name_list.append(f'@{member["nick_name"]}') msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}' # 统一消息派发 handler_map = { 'text': lambda msg: self.bot.send_text_message( - to_wxid=target_id, - message=msg['content'], - ats=at_targets + to_wxid=target_id, message=msg['content'], ats=at_targets ), 'image': lambda msg: self.bot.send_image_message( - to_wxid=target_id, - img_url=msg["image"], - ats = at_targets + to_wxid=target_id, img_url=msg['image'], ats=at_targets ), 'WeChatEmoji': lambda msg: self.bot.send_emoji_message( - to_wxid=target_id, - emoji_md5=msg['emoji_md5'], - emoji_size=msg['emoji_size'] + to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size'] ), - 'voice': lambda msg: self.bot.send_voice_message( to_wxid=target_id, voice_data=msg['data'], - voice_duration=msg["duration"], - voice_forma=msg["forma"], + voice_duration=msg['duration'], + voice_forma=msg['forma'], ), 'WeChatAppMsg': lambda msg: self.bot.send_app_message( to_wxid=target_id, app_message=msg['app_msg'], type=0, ), - 'at': lambda msg: None + 'at': lambda msg: None, } if handler := handler_map.get(msg['type']): handler(msg) # self.ap.logger.warning(f"未处理的消息类型: {ret}") else: - self.ap.logger.warning(f"未处理的消息类型: {msg['type']}") + self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') continue - async def send_message( - self, - target_type: str, - target_id: str, - message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): """主动发送消息""" return await self._handle_message(message, target_id) async def reply_message( - self, - message_source: platform_events.MessageEvent, - message: platform_message.MessageChain, - quote_origin: bool = False + self, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, + quote_origin: bool = False, ): """回复消息""" if message_source.source_platform_object: @@ -683,58 +596,49 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): pass def register_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners[event_type] = callback def unregister_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): pass async def run_async(self): - - if not self.config["admin_key"] and not self.config["token"]: - raise RuntimeError("无wechatpad管理密匙,请填入配置文件后重启") + if not self.config['admin_key'] and not self.config['token']: + raise RuntimeError('无wechatpad管理密匙,请填入配置文件后重启') else: - if self.config["token"]: - self.bot = WeChatPadClient( - self.config['wechatpad_url'], - self.config["token"] - ) + if self.config['token']: + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) data = self.bot.get_login_status() self.ap.logger.info(data) - if data["Code"] == 300 and data["Text"] == "你已退出微信": + if data['Code'] == 300 and data['Text'] == '你已退出微信': response = requests.post( - f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", - json={"Count": 1, "Days": 365} + f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', + json={'Count': 1, 'Days': 365}, ) if response.status_code != 200: - raise Exception(f"获取token失败: {response.text}") - self.config["token"] = response.json()["Data"][0] + raise Exception(f'获取token失败: {response.text}') + self.config['token'] = response.json()['Data'][0] - elif not self.config["token"]: + elif not self.config['token']: response = requests.post( - f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", - json={"Count": 1, "Days": 365} + f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', + json={'Count': 1, 'Days': 365}, ) if response.status_code != 200: - raise Exception(f"获取token失败: {response.text}") - self.config["token"] = response.json()["Data"][0] + raise Exception(f'获取token失败: {response.text}') + self.config['token'] = response.json()['Data'][0] - self.bot = WeChatPadClient( - self.config['wechatpad_url'], - self.config["token"], - logger=self.logger - ) - self.ap.logger.info(self.config["token"]) + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger) + self.ap.logger.info(self.config['token']) thread_1 = threading.Event() - def wechat_login_process(): # 不登录,这些先注释掉,避免登陆态尝试拉qrcode。 # login_data =self.bot.get_login_qr() @@ -742,67 +646,54 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # url = login_data['Data']["QrCodeUrl"] # self.ap.logger.info(login_data) - - profile =self.bot.get_profile() + profile = self.bot.get_profile() self.ap.logger.info(profile) - self.bot_account_id = profile["Data"]["userInfo"]["nickName"]["str"] - self.config["wxid"] = profile["Data"]["userInfo"]["userName"]["str"] + self.bot_account_id = profile['Data']['userInfo']['nickName']['str'] + self.config['wxid'] = profile['Data']['userInfo']['userName']['str'] thread_1.set() - # asyncio.create_task(wechat_login_process) threading.Thread(target=wechat_login_process).start() def connect_websocket_sync() -> None: - thread_1.wait() - uri = f"{self.config['wechatpad_ws']}/GetSyncMsg?key={self.config['token']}" - self.ap.logger.info(f"Connecting to WebSocket: {uri}") + uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}' + self.ap.logger.info(f'Connecting to WebSocket: {uri}') + def on_message(ws, message): try: data = json.loads(message) - self.ap.logger.debug(f"Received message: {data}") + self.ap.logger.debug(f'Received message: {data}') # 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法 asyncio.run(self.ws_message(data)) except json.JSONDecodeError: - self.ap.logger.error(f"Non-JSON message: {message[:100]}...") + self.ap.logger.error(f'Non-JSON message: {message[:100]}...') def on_error(ws, error): - self.ap.logger.error(f"WebSocket error: {str(error)[:200]}") + self.ap.logger.error(f'WebSocket error: {str(error)[:200]}') def on_close(ws, close_status_code, close_msg): - self.ap.logger.info("WebSocket closed, reconnecting...") + self.ap.logger.info('WebSocket closed, reconnecting...') time.sleep(5) connect_websocket_sync() # 自动重连 def on_open(ws): - self.ap.logger.info("WebSocket connected successfully!") + self.ap.logger.info('WebSocket connected successfully!') ws = websocket.WebSocketApp( - uri, - on_message=on_message, - on_error=on_error, - on_close=on_close, - on_open=on_open - ) - ws.run_forever( - ping_interval=60, - ping_timeout=20 + uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open ) + ws.run_forever(ping_interval=60, ping_timeout=20) # 直接调用同步版本(会阻塞) # connect_websocket_sync() # 这行代码会在WebSocket连接断开后才会执行 # self.ap.logger.info("WebSocket client thread started") - thread = threading.Thread( - target=connect_websocket_sync, - name="WebSocketClientThread", - daemon=True - ) + thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True) thread.start() - self.ap.logger.info("WebSocket client thread started") + self.ap.logger.info('WebSocket client thread started') async def kill(self) -> bool: pass diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index f1cc677e..7be05a85 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -157,7 +157,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter): token=config['token'], EncodingAESKey=config['EncodingAESKey'], contacts_secret=config['contacts_secret'], - logger=self.logger + logger=self.logger, ) async def reply_message( @@ -201,8 +201,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in wecom callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/platform/sources/wecomcs.py b/pkg/platform/sources/wecomcs.py index aab8d394..da84ac6d 100644 --- a/pkg/platform/sources/wecomcs.py +++ b/pkg/platform/sources/wecomcs.py @@ -145,7 +145,7 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): secret=config['secret'], token=config['token'], EncodingAESKey=config['EncodingAESKey'], - logger=self.logger + logger=self.logger, ) async def reply_message( @@ -178,8 +178,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in wecomcs callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py index a8c35883..35a52453 100644 --- a/pkg/rag/knowledge/services/database.py +++ b/pkg/rag/knowledge/services/database.py @@ -1,19 +1,20 @@ from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary from sqlalchemy.orm import declarative_base, sessionmaker, relationship from datetime import datetime -import numpy as np # 用于处理从LargeBinary转换回来的embedding Base = declarative_base() + class KnowledgeBase(Base): __tablename__ = 'kb' id = Column(Integer, primary_key=True, index=True) name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) - embedding_model = Column(String, default="") # 默认嵌入模型 + embedding_model = Column(String, default='') # 默认嵌入模型 top_k = Column(Integer, default=5) # 默认返回的top_k数量 - files = relationship("File", back_populates="knowledge_base") + files = relationship('File', back_populates='knowledge_base') + class File(Base): __tablename__ = 'file' @@ -24,8 +25,9 @@ class File(Base): created_at = Column(DateTime, default=datetime.utcnow) file_type = Column(String) status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误 - knowledge_base = relationship("KnowledgeBase", back_populates="files") - chunks = relationship("Chunk", back_populates="file") + knowledge_base = relationship('KnowledgeBase', back_populates='files') + chunks = relationship('Chunk', back_populates='file') + class Chunk(Base): __tablename__ = 'chunks' @@ -33,26 +35,30 @@ class Chunk(Base): file_id = Column(Integer, ForeignKey('file.id')) text = Column(Text) - file = relationship("File", back_populates="chunks") - vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one + file = relationship('File', back_populates='chunks') + vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one + class Vector(Base): __tablename__ = 'vectors' id = Column(Integer, primary_key=True, index=True) chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) - embedding = Column(LargeBinary) # Store embeddings as binary + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship('Chunk', back_populates='vector') - chunk = relationship("Chunk", back_populates="vector") # 数据库连接 -DATABASE_URL = "sqlite:///./knowledge_base.db" # 生产环境请更换为 PostgreSQL/MySQL -engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}) +DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL +engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {}) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + # 创建所有表 (可以在应用启动时执行一次) def create_db_and_tables(): Base.metadata.create_all(bind=engine) - print("Database tables created/checked.") + print('Database tables created/checked.') + # 定义嵌入维度(请根据你实际使用的模型调整) -EMBEDDING_DIM = 1024 \ No newline at end of file +EMBEDDING_DIM = 1024 diff --git a/pkg/rag/knowledge/services/embedding_models.py b/pkg/rag/knowledge/services/embedding_models.py index a6ce73ae..7301d640 100644 --- a/pkg/rag/knowledge/services/embedding_models.py +++ b/pkg/rag/knowledge/services/embedding_models.py @@ -1,14 +1,15 @@ # services/embedding_models.py import os -from typing import Dict, Any, List, Type, Optional +from typing import Dict, Any, List import logging -import aiohttp # Import aiohttp for asynchronous requests +import aiohttp # Import aiohttp for asynchronous requests import asyncio from sentence_transformers import SentenceTransformer logger = logging.getLogger(__name__) + # Base class for all embedding models class BaseEmbeddingModel: def __init__(self, model_name: str): @@ -27,9 +28,10 @@ class BaseEmbeddingModel: def embedding_dimension(self) -> int: """Returns the embedding dimension of the model.""" if self._embedding_dimension is None: - raise NotImplementedError("Embedding dimension not set for this model.") + raise NotImplementedError('Embedding dimension not set for this model.') return self._embedding_dimension - + + class EmbeddingModelFactory: @staticmethod def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel: @@ -39,26 +41,29 @@ class EmbeddingModelFactory: """ if model_name_key not in EMBEDDING_MODEL_CONFIGS: raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.") - + config = EMBEDDING_MODEL_CONFIGS[model_name_key] - - if config['type'] == "third_party_api": + + if config['type'] == 'third_party_api': required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension'] if not all(key in config for key in required_keys): - raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}") - + raise ValueError( + f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}" + ) + # Retrieve model_name from config if it differs from model_name_key # Some APIs expect a specific 'model' value in the payload that might be different from the key - api_model_name = config.get('model_name', model_name_key) + api_model_name = config.get('model_name', model_name_key) return ThirdPartyAPIEmbeddingModel( - model_name=api_model_name, # Use the model_name from config or the key + model_name=api_model_name, # Use the model_name from config or the key api_endpoint=config['api_endpoint'], headers=config['headers'], payload_template=config['payload_template'], - embedding_dimension=config['embedding_dimension'] + embedding_dimension=config['embedding_dimension'], ) + class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): def __init__(self, model_name: str): super().__init__(model_name) @@ -68,9 +73,11 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): # if not run in a separate thread/process, but this keeps the API consistent. self.model = SentenceTransformer(model_name) self._embedding_dimension = self.model.get_sentence_embedding_dimension() - logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}") + logger.info( + f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}" + ) except Exception as e: - logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}") + logger.error(f'Failed to load SentenceTransformer model {model_name}: {e}') raise async def embed_documents(self, texts: List[str]) -> List[List[float]]: @@ -84,14 +91,23 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): - def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int): + def __init__( + self, + model_name: str, + api_endpoint: str, + headers: Dict[str, str], + payload_template: Dict[str, Any], + embedding_dimension: int, + ): super().__init__(model_name) self.api_endpoint = api_endpoint self.headers = headers self.payload_template = payload_template self._embedding_dimension = embedding_dimension - self.session = None # aiohttp client session will be initialized on first use or in a context manager - logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}") + self.session = None # aiohttp client session will be initialized on first use or in a context manager + logger.info( + f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}" + ) async def _get_session(self): """Lazily create or return the aiohttp client session.""" @@ -104,7 +120,7 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): if self.session and not self.session.closed: await self.session.close() self.session = None - logger.info(f"Closed aiohttp session for model {self.model_name}") + logger.info(f'Closed aiohttp session for model {self.model_name}') async def embed_documents(self, texts: List[str]) -> List[List[float]]: """Asynchronously embeds a list of texts using the third-party API.""" @@ -118,10 +134,10 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): elif 'texts' in payload: payload['texts'] = [text] else: - raise ValueError("Payload template does not contain expected text input key.") + raise ValueError('Payload template does not contain expected text input key.') tasks.append(self._make_api_request(session, payload)) - + results = await asyncio.gather(*tasks, return_exceptions=True) for i, res in enumerate(results): @@ -131,93 +147,92 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): # - Append None or an empty list # - Re-raise the exception to stop processing # - Log and skip, then continue - embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure + embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure else: embeddings.append(res) - + return embeddings async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]: """Helper to make an asynchronous API request and extract embedding.""" try: async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response: - response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx) + response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx) api_response = await response.json() - + # Adjust this based on your API's actual response structure - if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]: - embedding = api_response["data"][0]["embedding"] + if 'data' in api_response and len(api_response['data']) > 0 and 'embedding' in api_response['data'][0]: + embedding = api_response['data'][0]['embedding'] if len(embedding) != self.embedding_dimension: - logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + logger.warning( + f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.' + ) return embedding - elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]: - embedding = api_response["embeddings"][0] + elif ( + 'embeddings' in api_response + and isinstance(api_response['embeddings'], list) + and api_response['embeddings'] + ): + embedding = api_response['embeddings'][0] if len(embedding) != self.embedding_dimension: - logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + logger.warning( + f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.' + ) return embedding else: - raise ValueError(f"Unexpected API response structure: {api_response}") + raise ValueError(f'Unexpected API response structure: {api_response}') except aiohttp.ClientError as e: - raise ConnectionError(f"API request failed: {e}") from e + raise ConnectionError(f'API request failed: {e}') from e except ValueError as e: - raise ValueError(f"Error processing API response: {e}") from e - + raise ValueError(f'Error processing API response: {e}') from e async def embed_query(self, text: str) -> List[float]: """Asynchronously embeds a single query text.""" results = await self.embed_documents([text]) if results: return results[0] - return [] # Or raise an error if embedding a query must always succeed + return [] # Or raise an error if embedding a query must always succeed + # --- Embedding Model Configuration --- EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = { - "MiniLM": { # Example for a local Sentence Transformer model - "type": "sentence_transformer", - "model_name": "sentence-transformers/all-MiniLM-L6-v2" + 'MiniLM': { # Example for a local Sentence Transformer model + 'type': 'sentence_transformer', + 'model_name': 'sentence-transformers/all-MiniLM-L6-v2', }, - "bge-m3": { # Example for a third-party API model - "type": "third_party_api", - "model_name": "bge-m3", - "api_endpoint": "https://api.qhaigc.net/v1/embeddings", - "headers": { - "Content-Type": "application/json", - "Authorization": f"Bearer {os.getenv('rag_api_key')}" - }, - "payload_template": { - "model": "bge-m3", - "input": "" - }, - "embedding_dimension": 1024 + 'bge-m3': { # Example for a third-party API model + 'type': 'third_party_api', + 'model_name': 'bge-m3', + 'api_endpoint': 'https://api.qhaigc.net/v1/embeddings', + 'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("rag_api_key")}'}, + 'payload_template': {'model': 'bge-m3', 'input': ''}, + 'embedding_dimension': 1024, }, - "OpenAI-Ada-002": { - "type": "third_party_api", - "model_name": "text-embedding-ada-002", - "api_endpoint": "https://api.openai.com/v1/embeddings", - "headers": { - "Content-Type": "application/json", - "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set + 'OpenAI-Ada-002': { + 'type': 'third_party_api', + 'model_name': 'text-embedding-ada-002', + 'api_endpoint': 'https://api.openai.com/v1/embeddings', + 'headers': { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}', # Ensure OPENAI_API_KEY is set }, - "payload_template": { - "model": "text-embedding-ada-002", - "input": "" # Text will be injected here + 'payload_template': { + 'model': 'text-embedding-ada-002', + 'input': '', # Text will be injected here }, - "embedding_dimension": 1536 + 'embedding_dimension': 1536, }, - "OpenAI-Embedding-3-Small": { - "type": "third_party_api", - "model_name": "text-embedding-3-small", - "api_endpoint": "https://api.openai.com/v1/embeddings", - "headers": { - "Content-Type": "application/json", - "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" - }, - "payload_template": { - "model": "text-embedding-3-small", - "input": "", + 'OpenAI-Embedding-3-Small': { + 'type': 'third_party_api', + 'model_name': 'text-embedding-3-small', + 'api_endpoint': 'https://api.openai.com/v1/embeddings', + 'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}'}, + 'payload_template': { + 'model': 'text-embedding-3-small', + 'input': '', # "dimensions": 512 # Optional: uncomment if you want a specific output dimension }, - "embedding_dimension": 1536 # Default max dimension for text-embedding-3-small + 'embedding_dimension': 1536, # Default max dimension for text-embedding-3-small }, -} \ No newline at end of file +} diff --git a/pkg/rag/knowledge/services/parser.py b/pkg/rag/knowledge/services/parser.py index 5fa7d589..bea49721 100644 --- a/pkg/rag/knowledge/services/parser.py +++ b/pkg/rag/knowledge/services/parser.py @@ -1,22 +1,21 @@ - import PyPDF2 from docx import Document import pandas as pd -import csv import chardet -from typing import Union, List, Callable, Any +from typing import Union, Callable, Any import logging import markdown from bs4 import BeautifulSoup import ebooklib from ebooklib import epub import re -import asyncio # Import asyncio for async operations +import asyncio # Import asyncio for async operations import os # Configure logging logger = logging.getLogger(__name__) + class FileParser: """ A robust file parser class to extract text content from various document formats. @@ -24,8 +23,8 @@ class FileParser: All core file reading operations are designed to be run synchronously in a thread pool to avoid blocking the asyncio event loop. """ + def __init__(self): - self.logger = logging.getLogger(self.__class__.__name__) async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any: @@ -36,14 +35,14 @@ class FileParser: try: return await asyncio.to_thread(sync_func, *args, **kwargs) except Exception as e: - self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}") + self.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}') raise async def parse(self, file_path: str) -> Union[str, None]: """ Parses the file based on its extension and returns the extracted text content. This is the main asynchronous entry point for parsing. - + Args: file_path (str): The path to the file to be parsed. @@ -51,21 +50,21 @@ class FileParser: Union[str, None]: The extracted text content as a single string, or None if parsing fails. """ if not file_path or not os.path.exists(file_path): - self.logger.error(f"Invalid file path provided: {file_path}") + self.logger.error(f'Invalid file path provided: {file_path}') return None file_extension = file_path.split('.')[-1].lower() parser_method = getattr(self, f'_parse_{file_extension}', None) - + if parser_method is None: - self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}") + self.logger.error(f'Unsupported file format: {file_extension} for file {file_path}') return None - + try: # Pass file_path to the specific parser methods return await parser_method(file_path) except Exception as e: - self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}") + self.logger.error(f'Failed to parse {file_extension} file {file_path}: {e}') return None # --- Helper for reading files with encoding detection --- @@ -74,15 +73,16 @@ class FileParser: Reads a file with automatic encoding detection, ensuring the synchronous file read operation runs in a separate thread. """ + def _read_sync(): with open(file_path, 'rb') as file: raw_data = file.read() detected = chardet.detect(raw_data) encoding = detected['encoding'] or 'utf-8' - + if mode == 'r': return raw_data.decode(encoding, errors='ignore') - return raw_data # For binary mode + return raw_data # For binary mode return await self._run_sync(_read_sync) @@ -90,12 +90,13 @@ class FileParser: async def _parse_txt(self, file_path: str) -> str: """Parses a TXT file and returns its content.""" - self.logger.info(f"Parsing TXT file: {file_path}") + self.logger.info(f'Parsing TXT file: {file_path}') return await self._read_file_content(file_path, mode='r') async def _parse_pdf(self, file_path: str) -> str: """Parses a PDF file and returns its text content.""" - self.logger.info(f"Parsing PDF file: {file_path}") + self.logger.info(f'Parsing PDF file: {file_path}') + def _parse_pdf_sync(): text_content = [] with open(file_path, 'rb') as file: @@ -105,57 +106,69 @@ class FileParser: if text: text_content.append(text) return '\n'.join(text_content) + return await self._run_sync(_parse_pdf_sync) async def _parse_docx(self, file_path: str) -> str: """Parses a DOCX file and returns its text content.""" - self.logger.info(f"Parsing DOCX file: {file_path}") + self.logger.info(f'Parsing DOCX file: {file_path}') + def _parse_docx_sync(): doc = Document(file_path) text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()] return '\n'.join(text_content) + return await self._run_sync(_parse_docx_sync) - + async def _parse_doc(self, file_path: str) -> str: """Handles .doc files, explicitly stating lack of direct support.""" - self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.") - raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.") - + self.logger.warning(f'Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.') + raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.') + async def _parse_xlsx(self, file_path: str) -> str: """Parses an XLSX file, returning text from all sheets.""" - self.logger.info(f"Parsing XLSX file: {file_path}") + self.logger.info(f'Parsing XLSX file: {file_path}') + def _parse_xlsx_sync(): excel_file = pd.ExcelFile(file_path) all_sheet_content = [] for sheet_name in excel_file.sheet_names: df = pd.read_excel(file_path, sheet_name=sheet_name) - sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n" + sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n' all_sheet_content.append(sheet_text) return '\n'.join(all_sheet_content) + return await self._run_sync(_parse_xlsx_sync) - + async def _parse_csv(self, file_path: str) -> str: """Parses a CSV file and returns its content as a string.""" - self.logger.info(f"Parsing CSV file: {file_path}") + self.logger.info(f'Parsing CSV file: {file_path}') + def _parse_csv_sync(): # pd.read_csv can often detect encoding, but explicit detection is safer - raw_data = self._read_file_content(file_path, mode='rb') # Note: this will need to be await outside this sync function + raw_data = self._read_file_content( + file_path, mode='rb' + ) # Note: this will need to be await outside this sync function + _ = raw_data # For simplicity, we'll let pandas handle encoding internally after a raw read. # A more robust solution might pass encoding directly to pd.read_csv after detection. detected = chardet.detect(open(file_path, 'rb').read()) encoding = detected['encoding'] or 'utf-8' df = pd.read_csv(file_path, encoding=encoding) return df.to_string(index=False) + return await self._run_sync(_parse_csv_sync) - + async def _parse_markdown(self, file_path: str) -> str: """Parses a Markdown file, converting it to structured plain text.""" - self.logger.info(f"Parsing Markdown file: {file_path}") + self.logger.info(f'Parsing Markdown file: {file_path}') + def _parse_markdown_sync(): - md_content = self._read_file_content(file_path, mode='r') # This is a synchronous call within a sync function + md_content = self._read_file_content( + file_path, mode='r' + ) # This is a synchronous call within a sync function html_content = markdown.markdown( - md_content, - extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code'] + md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code'] ) soup = BeautifulSoup(html_content, 'html.parser') text_parts = [] @@ -169,13 +182,13 @@ class FileParser: text_parts.append(text) elif element.name in ['ul', 'ol']: for li in element.find_all('li'): - text_parts.append(f"* {li.get_text().strip()}") + text_parts.append(f'* {li.get_text().strip()}') elif element.name == 'pre': code_block = element.get_text().strip() if code_block: - text_parts.append(f"```\n{code_block}\n```") + text_parts.append(f'```\n{code_block}\n```') elif element.name == 'table': - table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper if table_str: text_parts.append(table_str) elif element.name: @@ -184,15 +197,17 @@ class FileParser: text_parts.append(text) cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) return cleaned_text.strip() + return await self._run_sync(_parse_markdown_sync) async def _parse_html(self, file_path: str) -> str: """Parses an HTML file, extracting structured plain text.""" - self.logger.info(f"Parsing HTML file: {file_path}") + self.logger.info(f'Parsing HTML file: {file_path}') + def _parse_html_sync(): - html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function + html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function soup = BeautifulSoup(html_content, 'html.parser') - for script_or_style in soup(["script", "style"]): + for script_or_style in soup(['script', 'style']): script_or_style.decompose() text_parts = [] for element in soup.body.children if soup.body else soup.children: @@ -207,9 +222,9 @@ class FileParser: for li in element.find_all('li'): text = li.get_text().strip() if text: - text_parts.append(f"* {text}") + text_parts.append(f'* {text}') elif element.name == 'table': - table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper if table_str: text_parts.append(table_str) elif element.name: @@ -218,39 +233,42 @@ class FileParser: text_parts.append(text) cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) return cleaned_text.strip() + return await self._run_sync(_parse_html_sync) - + async def _parse_epub(self, file_path: str) -> str: """Parses an EPUB file, extracting metadata and content.""" - self.logger.info(f"Parsing EPUB file: {file_path}") + self.logger.info(f'Parsing EPUB file: {file_path}') + def _parse_epub_sync(): book = epub.read_epub(file_path) text_content = [] title_meta = book.get_metadata('DC', 'title') if title_meta: - text_content.append(f"Title: {title_meta[0][0]}") + text_content.append(f'Title: {title_meta[0][0]}') creator_meta = book.get_metadata('DC', 'creator') if creator_meta: - text_content.append(f"Author: {creator_meta[0][0]}") + text_content.append(f'Author: {creator_meta[0][0]}') date_meta = book.get_metadata('DC', 'date') if date_meta: - text_content.append(f"Publish Date: {date_meta[0][0]}") + text_content.append(f'Publish Date: {date_meta[0][0]}') toc = book.get_toc() if toc: - text_content.append("\n--- Table of Contents ---") - self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper - text_content.append("--- End of Table of Contents ---\n") + text_content.append('\n--- Table of Contents ---') + self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper + text_content.append('--- End of Table of Contents ---\n') for item in book.get_items(): if item.get_type() == ebooklib.ITEM_DOCUMENT: html_content = item.get_content().decode('utf-8', errors='ignore') soup = BeautifulSoup(html_content, 'html.parser') - for junk in soup(["script", "style", "nav", "header", "footer"]): + for junk in soup(['script', 'style', 'nav', 'header', 'footer']): junk.decompose() text = soup.get_text(separator='\n', strip=True) text = re.sub(r'\n\s*\n', '\n\n', text) if text: text_content.append(text) return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip() + return await self._run_sync(_parse_epub_sync) def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int): @@ -259,10 +277,10 @@ class FileParser: for item in toc_list: if isinstance(item, tuple): chapter, subchapters = item - text_content.append(f"{indent}- {chapter.title}") + text_content.append(f'{indent}- {chapter.title}') self._add_toc_items_sync(subchapters, text_content, level + 1) else: - text_content.append(f"{indent}- {item.title}") + text_content.append(f'{indent}- {item.title}') def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str: """Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous).""" @@ -272,17 +290,17 @@ class FileParser: cells = [td.get_text().strip() for td in tr.find_all('td')] if cells: rows.append(cells) - + if not headers and not rows: - return "" + return '' table_lines = [] if headers: table_lines.append(' | '.join(headers)) table_lines.append(' | '.join(['---'] * len(headers))) - + for row_cells in rows: padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells table_lines.append(' | '.join(padded_cells)) - - return '\n'.join(table_lines) \ No newline at end of file + + return '\n'.join(table_lines) diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index 4da81eb1..f563f9b3 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -1,7 +1,6 @@ # services/retriever.py -import asyncio import logging -import numpy as np # Make sure numpy is imported +import numpy as np # Make sure numpy is imported from typing import List, Dict, Any from sqlalchemy.orm import Session from pkg.rag.knowledge.services.base_service import BaseService @@ -11,6 +10,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager logger = logging.getLogger(__name__) + class Retriever(BaseService): def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): super().__init__() @@ -22,10 +22,14 @@ class Retriever(BaseService): self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() def _load_embedding_model(self) -> BaseEmbeddingModel: - self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...") + self.logger.info( + f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...' + ) try: model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) - self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") + self.logger.info( + f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}" + ) return model except Exception as e: self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}") @@ -33,43 +37,42 @@ class Retriever(BaseService): async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: if not self.embedding_model: - raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.") + raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.') self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}") query_embedding: List[float] = await self.embedding_model.embed_query(query) query_embedding_np = np.array([query_embedding], dtype=np.float32) - chroma_results = await self._run_sync( - self.chroma_manager.search_sync, - query_embedding_np, k - ) + chroma_results = await self._run_sync(self.chroma_manager.search_sync, query_embedding_np, k) # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' - matched_chroma_ids = chroma_results.get("ids", [[]])[0] - distances = chroma_results.get("distances", [[]])[0] - chroma_metadatas = chroma_results.get("metadatas", [[]])[0] - chroma_documents = chroma_results.get("documents", [[]])[0] + matched_chroma_ids = chroma_results.get('ids', [[]])[0] + distances = chroma_results.get('distances', [[]])[0] + chroma_metadatas = chroma_results.get('metadatas', [[]])[0] + chroma_documents = chroma_results.get('documents', [[]])[0] if not matched_chroma_ids: - self.logger.info("No relevant chunks found in Chroma.") + self.logger.info('No relevant chunks found in Chroma.') return [] db_chunk_ids = [] for metadata in chroma_metadatas: - if "chunk_id" in metadata: - db_chunk_ids.append(metadata["chunk_id"]) + if 'chunk_id' in metadata: + db_chunk_ids.append(metadata['chunk_id']) else: self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.") if not db_chunk_ids: - self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.") + self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.') return [] - self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...") + self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...') chunks_from_db = await self._run_sync( - lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync - db_chunk_ids + lambda cids: self._db_get_chunks_sync( + SessionLocal(), cids + ), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync + db_chunk_ids, ) chunk_map = {chunk.id: chunk for chunk in chunks_from_db} @@ -80,27 +83,29 @@ class Retriever(BaseService): # Ensure original_chunk_id is int for DB lookup original_chunk_id = int(chroma_id.split('_')[-1]) except (ValueError, IndexError): - self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.") + self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.') continue chunk_text_from_chroma = chroma_documents[i] distance = float(distances[i]) - file_id_from_chroma = chroma_metadatas[i].get("file_id") + file_id_from_chroma = chroma_metadatas[i].get('file_id') chunk_from_db = chunk_map.get(original_chunk_id) - results_list.append({ - "chunk_id": original_chunk_id, - "text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, - "distance": distance, - "file_id": file_id_from_chroma - }) + results_list.append( + { + 'chunk_id': original_chunk_id, + 'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, + 'distance': distance, + 'file_id': file_id_from_chroma, + } + ) - self.logger.info(f"Retrieved {len(results_list)} chunks.") + self.logger.info(f'Retrieved {len(results_list)} chunks.') return results_list def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]: - self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).") + self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).') chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all() session.close() - return chunks \ No newline at end of file + return chunks From bef0d73e83703e79e31bc178f96edfbb943c8d9b Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 6 Jul 2025 10:25:28 +0800 Subject: [PATCH 17/60] feat: basic definition --- web/src/app/home/knowledge/KBDetailDialog.tsx | 89 +++++++++++++++++++ web/src/app/infra/entities/api/index.ts | 10 +++ web/src/app/infra/http/HttpClient.ts | 21 +++++ web/src/i18n/locales/en-US.ts | 2 + 4 files changed, 122 insertions(+) create mode 100644 web/src/app/home/knowledge/KBDetailDialog.tsx diff --git a/web/src/app/home/knowledge/KBDetailDialog.tsx b/web/src/app/home/knowledge/KBDetailDialog.tsx new file mode 100644 index 00000000..5291dd59 --- /dev/null +++ b/web/src/app/home/knowledge/KBDetailDialog.tsx @@ -0,0 +1,89 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogFooter, +} from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { KnowledgeBase } from '@/app/infra/entities/api'; + +interface KBDetailDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + kbId?: string; + onFormSubmit: (value: z.infer) => void; + onFormCancel: () => void; + onKbDeleted: () => void; + onNewKbCreated: (kbId: string) => void; +} + +export default function KBDetailDialog({ + open, + onOpenChange, + kbId: propKbId, + onFormSubmit, + onFormCancel, + onKbDeleted, + onNewKbCreated, +}: KBDetailDialogProps) { + const { t } = useTranslation(); + const [kbId, setKbId] = useState(propKbId); + const [activeMenu, setActiveMenu] = useState('metadata'); + const [showDeleteConfirm, setShowDeleteConfirm] = useState(false); + + useEffect(() => { + setKbId(propKbId); + setActiveMenu('metadata'); + }, [propKbId, open]); + + const menu = [ + { + key: 'metadata', + label: t('knowledge.metadata'), + icon: ( + + + + ), + }, + { + key: 'files', + label: t('knowledge.files'), + icon: ( + + + + ), + }, + ]; + + if (!kbId) { + // new kb + return ( + + +
+ + {t('knowledge.newKb')} + +
+
+
+ ); + } +} diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index a44b1991..b230cf9e 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -150,6 +150,16 @@ export interface KnowledgeBase { updated_at?: string; } +export interface ApiRespKnowledgeBaseFiles { + files: KnowledgeBaseFile[]; +} + +export interface KnowledgeBaseFile { + file_id: string; + file_name: string; + status: string; +} + // plugins export interface ApiRespPlugins { plugins: Plugin[]; diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 5c6e0abd..8842b04d 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -37,6 +37,7 @@ import { ApiRespKnowledgeBases, ApiRespKnowledgeBase, KnowledgeBase, + ApiRespKnowledgeBaseFiles, } from '@/app/infra/entities/api'; import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest'; import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse'; @@ -430,6 +431,11 @@ class HttpClient { return this.post(`/api/v1/platform/bots/${botId}/logs`, request); } + // ============ File management API ============ + public uploadDocumentFile(file: File): Promise<{ file_id: string }> { + return this.post('/api/v1/files/documents', file); + } + // ============ Knowledge Base API ============ public getKnowledgeBases(): Promise { return this.get('/api/v1/knowledge/bases'); @@ -443,6 +449,21 @@ class HttpClient { return this.post('/api/v1/knowledge/bases', base); } + public uploadKnowledgeBaseFile( + uuid: string, + file_id: string, + ): Promise { + return this.post(`/api/v1/knowledge/bases/${uuid}/files`, { + file_id, + }); + } + + public getKnowledgeBaseFiles( + uuid: string, + ): Promise { + return this.get(`/api/v1/knowledge/bases/${uuid}/files`); + } + // ============ Plugins API ============ public getPlugins(): Promise { return this.get('/api/v1/plugins'); diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 5596e35f..7a1f79c8 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -233,6 +233,8 @@ const enUS = { knowledge: { title: 'Knowledge', description: 'Configuring knowledge bases for improved LLM responses', + metadata: 'Metadata', + files: 'Files', }, register: { title: 'Initialize LangBot 👋', From ebd8e014c61d7389dd44b16d840abd7160d5854d Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 6 Jul 2025 15:52:53 +0800 Subject: [PATCH 18/60] feat: rag fe framework --- web/src/app/home/bots/BotDetailDialog.tsx | 2 - .../home/bots/components/bot-form/BotForm.tsx | 41 ----- web/src/app/home/bots/page.tsx | 2 +- web/src/app/home/knowledge/KBDetailDialog.tsx | 134 +++++++++++++- .../knowledge/components/kb-card/KBCard.tsx | 2 +- .../knowledge/components/kb-docs/KBDoc.tsx | 0 .../components/kb-form/ChooseEntity.ts | 4 + .../knowledge/components/kb-form/KBForm.tsx | 172 ++++++++++++++++++ web/src/app/home/knowledge/page.tsx | 84 ++++++++- web/src/i18n/locales/en-US.ts | 16 +- 10 files changed, 398 insertions(+), 59 deletions(-) create mode 100644 web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx create mode 100644 web/src/app/home/knowledge/components/kb-form/ChooseEntity.ts create mode 100644 web/src/app/home/knowledge/components/kb-form/KBForm.tsx diff --git a/web/src/app/home/bots/BotDetailDialog.tsx b/web/src/app/home/bots/BotDetailDialog.tsx index 1c4a2403..cad04e7b 100644 --- a/web/src/app/home/bots/BotDetailDialog.tsx +++ b/web/src/app/home/bots/BotDetailDialog.tsx @@ -130,7 +130,6 @@ export default function BotDetailDialog({ onFormCancel={handleFormCancel} onBotDeleted={handleBotDeleted} onNewBotCreated={handleNewBotCreated} - hideButtons={true} /> @@ -202,7 +201,6 @@ export default function BotDetailDialog({ onFormCancel={handleFormCancel} onBotDeleted={handleBotDeleted} onNewBotCreated={handleNewBotCreated} - hideButtons={true} /> )} {activeMenu === 'logs' && botId && ( diff --git a/web/src/app/home/bots/components/bot-form/BotForm.tsx b/web/src/app/home/bots/components/bot-form/BotForm.tsx index 40a902c2..fe36d33b 100644 --- a/web/src/app/home/bots/components/bot-form/BotForm.tsx +++ b/web/src/app/home/bots/components/bot-form/BotForm.tsx @@ -67,14 +67,12 @@ export default function BotForm({ onFormCancel, onBotDeleted, onNewBotCreated, - hideButtons = false, }: { initBotId?: string; onFormSubmit: (value: z.infer>) => void; onFormCancel: () => void; onBotDeleted: () => void; onNewBotCreated: (botId: string) => void; - hideButtons?: boolean; }) { const { t } = useTranslation(); const formSchema = getFormSchema(t); @@ -527,45 +525,6 @@ export default function BotForm({ )} - - {!hideButtons && ( -
-
- {!initBotId && ( - - )} - {initBotId && ( - <> - - - - )} - -
-
- )} diff --git a/web/src/app/home/bots/page.tsx b/web/src/app/home/bots/page.tsx index d4305898..ad130fae 100644 --- a/web/src/app/home/bots/page.tsx +++ b/web/src/app/home/bots/page.tsx @@ -92,7 +92,7 @@ export default function BotConfigPage() { } return ( -
+
void; kbId?: string; + // eslint-disable-next-line @typescript-eslint/no-explicit-any onFormSubmit: (value: z.infer) => void; onFormCancel: () => void; onKbDeleted: () => void; @@ -36,7 +48,7 @@ export default function KBDetailDialog({ const { t } = useTranslation(); const [kbId, setKbId] = useState(propKbId); const [activeMenu, setActiveMenu] = useState('metadata'); - const [showDeleteConfirm, setShowDeleteConfirm] = useState(false); + // const [showDeleteConfirm, setShowDeleteConfirm] = useState(false); useEffect(() => { setKbId(propKbId); @@ -58,8 +70,8 @@ export default function KBDetailDialog({ ), }, { - key: 'files', - label: t('knowledge.files'), + key: 'documents', + label: t('knowledge.documents'), icon: (
- {t('knowledge.newKb')} + {t('knowledge.createKnowledgeBase')} +
+ {activeMenu === 'metadata' && ( + + )} + {activeMenu === 'documents' &&
documents
} +
+ {activeMenu === 'metadata' && ( + +
+ + +
+
+ )} ); } + + return ( + <> + + + + + + + + + {menu.map((item) => ( + + setActiveMenu(item.key)} + > + + {item.icon} + {item.label} + + + + ))} + + + + + +
+ + + {activeMenu === 'metadata' + ? t('knowledge.createKnowledgeBase') + : t('knowledge.editDocument')} + + +
+ {activeMenu === 'metadata' && ( + + )} + {activeMenu === 'documents' &&
documents
} +
+ {activeMenu === 'metadata' && ( + +
+ + + +
+
+ )} +
+
+
+
+ + ); } diff --git a/web/src/app/home/knowledge/components/kb-card/KBCard.tsx b/web/src/app/home/knowledge/components/kb-card/KBCard.tsx index 5d49e738..560b0497 100644 --- a/web/src/app/home/knowledge/components/kb-card/KBCard.tsx +++ b/web/src/app/home/knowledge/components/kb-card/KBCard.tsx @@ -26,7 +26,7 @@ export default function KBCard({ kbCardVO }: { kbCardVO: KnowledgeBaseVO }) {
- {t('knowledge.bases.updateTime')} + {t('knowledge.updateTime')} {kbCardVO.lastUpdatedTimeAgo}
diff --git a/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx new file mode 100644 index 00000000..e69de29b diff --git a/web/src/app/home/knowledge/components/kb-form/ChooseEntity.ts b/web/src/app/home/knowledge/components/kb-form/ChooseEntity.ts new file mode 100644 index 00000000..54f983e4 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-form/ChooseEntity.ts @@ -0,0 +1,4 @@ +export interface IEmbeddingModelEntity { + label: string; + value: string; +} diff --git a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx new file mode 100644 index 00000000..9ae51656 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx @@ -0,0 +1,172 @@ +import { useEffect, useState } from 'react'; +import { useForm } from 'react-hook-form'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { z } from 'zod'; +import { useTranslation } from 'react-i18next'; +import { Input } from '@/components/ui/input'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, + FormDescription, +} from '@/components/ui/form'; +import { IEmbeddingModelEntity } from './ChooseEntity'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; + +const getFormSchema = (t: (key: string) => string) => + z.object({ + name: z.string().min(1, { message: t('knowledge.kbNameRequired') }), + description: z + .string() + .min(1, { message: t('knowledge.kbDescriptionRequired') }), + embeddingModelUUID: z + .string() + .min(1, { message: t('knowledge.embeddingModelUUIDRequired') }), + }); + +export default function KBForm({ + initKbId, + onFormSubmit, + onFormCancel, + onKbDeleted, + onNewKbCreated, +}: { + initKbId?: string; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + onFormSubmit: (value: any) => void; + onFormCancel: () => void; + onKbDeleted: () => void; + onNewKbCreated: (kbId: string) => void; +}) { + const { t } = useTranslation(); + const formSchema = getFormSchema(t); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + name: '', + description: t('knowledge.defaultDescription'), + embeddingModelUUID: '', + }, + }); + + const [embeddingModelNameList, setEmbeddingModelNameList] = useState< + IEmbeddingModelEntity[] + >([]); + + useEffect(() => { + getEmbeddingModelNameList(); + }, []); + + const getEmbeddingModelNameList = async () => { + const resp = await httpClient.getProviderEmbeddingModels(); + setEmbeddingModelNameList( + resp.models.map((item) => { + return { + label: item.name, + value: item.uuid, + }; + }), + ); + }; + + return ( + <> +
+ +
+ ( + + + {t('knowledge.kbName')} + * + + + + + + + )} + /> + ( + + + {t('knowledge.kbDescription')} + * + + + + + + + )} + /> + ( + + + {t('knowledge.embeddingModelUUID')} + * + + +
+ +
+
+ + {t('knowledge.embeddingModelDescription')} + + +
+ )} + /> +
+
+ + + ); +} diff --git a/web/src/app/home/knowledge/page.tsx b/web/src/app/home/knowledge/page.tsx index 7ee25eac..99de73d8 100644 --- a/web/src/app/home/knowledge/page.tsx +++ b/web/src/app/home/knowledge/page.tsx @@ -3,32 +3,102 @@ import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; import styles from './knowledgeBase.module.css'; import { useTranslation } from 'react-i18next'; -import { useState } from 'react'; +import { useEffect, useState } from 'react'; import { KnowledgeBaseVO } from '@/app/home/knowledge/components/kb-card/KBCardVO'; import KBCard from '@/app/home/knowledge/components/kb-card/KBCard'; +import KBDetailDialog from '@/app/home/knowledge/KBDetailDialog'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { KnowledgeBase } from '@/app/infra/entities/api'; export default function KnowledgePage() { const { t } = useTranslation(); const [knowledgeBaseList, setKnowledgeBaseList] = useState( [], ); + const [selectedKbId, setSelectedKbId] = useState(''); + const [detailDialogOpen, setDetailDialogOpen] = useState(false); + + useEffect(() => { + getKnowledgeBaseList(); + }, []); + + async function getKnowledgeBaseList() { + const resp = await httpClient.getKnowledgeBases(); + setKnowledgeBaseList( + resp.bases.map((kb: KnowledgeBase) => { + const currentTime = new Date(); + const lastUpdatedTimeAgo = Math.floor( + (currentTime.getTime() - + new Date(kb.updated_at ?? currentTime.getTime()).getTime()) / + 1000 / + 60 / + 60 / + 24, + ); + + const lastUpdatedTimeAgoText = + lastUpdatedTimeAgo > 0 + ? ` ${lastUpdatedTimeAgo} ${t('knowledge.daysAgo')}` + : t('knowledge.today'); + + return new KnowledgeBaseVO({ + id: kb.uuid || '', + name: kb.name, + description: kb.description, + embeddingModelUUID: kb.embedding_model_uuid, + lastUpdatedTimeAgo: lastUpdatedTimeAgoText, + }); + }), + ); + } const handleKBCardClick = (kbId: string) => { - // setIsEditForm(false); - // setModalOpen(true); + setSelectedKbId(kbId); + setDetailDialogOpen(true); + }; + + const handleCreateKBClick = () => { + setSelectedKbId(''); + setDetailDialogOpen(true); + }; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const handleFormSubmit = (value: any) => { + console.log('handleFormSubmit', value); + }; + + const handleFormCancel = () => { + setDetailDialogOpen(false); + }; + + const handleKbDeleted = () => { + getKnowledgeBaseList(); + setDetailDialogOpen(false); + }; + + const handleNewKbCreated = () => { + getKnowledgeBaseList(); + setDetailDialogOpen(false); }; return (
+ +
{ - // setIsEditForm(false); - // setModalOpen(true); - }} + onClick={handleCreateKBClick} /> {knowledgeBaseList.map((kb) => { diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 7a1f79c8..e50e7cc7 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -232,9 +232,23 @@ const enUS = { }, knowledge: { title: 'Knowledge', + createKnowledgeBase: 'Create Knowledge Base', description: 'Configuring knowledge bases for improved LLM responses', metadata: 'Metadata', - files: 'Files', + documents: 'Documents', + kbNameRequired: 'Knowledge base name cannot be empty', + kbDescriptionRequired: 'Knowledge base description cannot be empty', + embeddingModelUUIDRequired: 'Embedding model cannot be empty', + daysAgo: 'days ago', + today: 'Today', + kbName: 'Knowledge Base Name', + kbDescription: 'Knowledge Base Description', + defaultDescription: 'A knowledge base', + embeddingModelUUID: 'Embedding Model', + selectEmbeddingModel: 'Select Embedding Model', + embeddingModelDescription: + 'Used to vectorize the text, you can configure it in the Models page', + updateTime: 'Updated ', }, register: { title: 'Initialize LangBot 👋', From cd2534082698870662df7e0ebb14499e16b4f012 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 6 Jul 2025 16:08:02 +0800 Subject: [PATCH 19/60] perf: en comments --- pkg/utils/constants.py | 2 +- web/src/app/home/knowledge/KBDetailDialog.tsx | 2 +- web/src/i18n/locales/en-US.ts | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index 8c4da3cc..e8193839 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,7 +1,7 @@ semantic_version = 'v4.0.8' required_database_version = 3 -"""标记本版本所需要的数据库结构版本,用于判断数据库迁移""" +"""Tag the version of the database schema, used to check if the database needs to be migrated""" debug_mode = False diff --git a/web/src/app/home/knowledge/KBDetailDialog.tsx b/web/src/app/home/knowledge/KBDetailDialog.tsx index d79702bc..e3ab4f9d 100644 --- a/web/src/app/home/knowledge/KBDetailDialog.tsx +++ b/web/src/app/home/knowledge/KBDetailDialog.tsx @@ -163,7 +163,7 @@ export default function KBDetailDialog({ {activeMenu === 'metadata' - ? t('knowledge.createKnowledgeBase') + ? t('knowledge.editKnowledgeBase') : t('knowledge.editDocument')} diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index e50e7cc7..ecc43204 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -233,6 +233,8 @@ const enUS = { knowledge: { title: 'Knowledge', createKnowledgeBase: 'Create Knowledge Base', + editKnowledgeBase: 'Edit Knowledge Base', + editDocument: 'Documents', description: 'Configuring knowledge bases for improved LLM responses', metadata: 'Metadata', documents: 'Documents', From ac03a2dceb1bcb7da5e10571630ee85d47079a58 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Wed, 9 Jul 2025 22:09:46 +0800 Subject: [PATCH 20/60] feat: modify the rag.py --- .../http/controller/groups/knowledge_base.py | 72 ++-- pkg/entity/persistence/rag.py | 58 +++ pkg/rag/knowledge/RAG_Manager.py | 354 +++++++++++------- pkg/rag/knowledge/services/database.py | 83 ++-- 4 files changed, 338 insertions(+), 229 deletions(-) create mode 100644 pkg/entity/persistence/rag.py diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py index e9606a3d..ce391042 100644 --- a/pkg/api/http/controller/groups/knowledge_base.py +++ b/pkg/api/http/controller/groups/knowledge_base.py @@ -1,6 +1,6 @@ import quart from .. import group - +import os # 导入 os 用于文件操作 @group.group_class('knowledge_base', '/api/v1/knowledge/bases') class KnowledgeBaseRouterGroup(group.RouterGroup): @@ -9,8 +9,8 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): return quart.jsonify({'code': code, 'data': data or {}, 'msg': msg}) async def initialize(self) -> None: - @self.route('', methods=['POST', 'GET']) - async def _() -> str: + @self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases') + async def handle_knowledge_bases() -> str: if quart.request.method == 'GET': knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases() bases_list = [ @@ -23,17 +23,17 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): ] return self.success(code=0, data={'bases': bases_list}, msg='ok') + # POST: create a new knowledge base json_data = await quart.request.json knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( json_data.get('name'), json_data.get('description') ) - _ = knowledge_base_uuid - return self.success(code=0, data={}, msg='ok') + return self.success(code=0, data={'uuid': knowledge_base_uuid}, msg='ok') - @self.route('/', methods=['GET', 'DELETE']) - async def _(knowledge_base_uuid: str) -> str: + @self.route('/', methods=['GET', 'DELETE'], endpoint='handle_specific_knowledge_base') + async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid) + knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(int(knowledge_base_uuid)) if knowledge_base is None: return self.http_status(404, -1, 'knowledge base not found') @@ -48,28 +48,42 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): msg='ok', ) elif quart.request.method == 'DELETE': - await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid) + await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid)) return self.success(code=0, msg='ok') - @self.route('//files', methods=['GET']) - async def _(knowledge_base_uuid: str) -> str: - if quart.request.method == 'GET': - files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) - return self.success( - code=0, - data=[ - { - 'id': file.id, - 'file_name': file.file_name, - 'status': file.status, - } - for file in files - ], - msg='ok', - ) - # delete specific file in knowledge base - @self.route('//files/', methods=['DELETE']) - async def _(knowledge_base_uuid: str, file_id: str) -> str: - await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) + @self.route('//files', methods=['GET'], endpoint='get_knowledge_base_files') + async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: + files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid)) + return self.success( + code=0, + data=[ + { + 'id': file.id, + 'file_name': file.file_name, + 'status': file.status, + } + for file in files + ], + msg='ok', + ) + + + @self.route('//files/', methods=['DELETE'], endpoint='delete_specific_file_in_kb') + async def delete_specific_file_in_kb(file_id: str) -> str: + await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id)) return self.success(code=0, msg='ok') + + @self.route('//files', methods=['POST'], endpoint='relate_file_with_kb') + async def relate_file_id_with_kb(knowledge_base_uuid:str,file_id: str) -> str: + if 'file' not in quart.request.files: + return self.http_status(400, -1, 'No file part in the request') + + json_data = await quart.request.json + file_id = json_data.get('file_id') + if not file_id: + return self.http_status(400, -1, 'File ID is required') + + # 调用服务层方法将文件与知识库关联 + await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id)) + return self.success(code=0, data={}, msg='ok') \ No newline at end of file diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py new file mode 100644 index 00000000..175720f1 --- /dev/null +++ b/pkg/entity/persistence/rag.py @@ -0,0 +1,58 @@ +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, LargeBinary +from sqlalchemy.orm import declarative_base, sessionmaker +from datetime import datetime +import os + + +Base = declarative_base() +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./rag_knowledge.db") + + +engine = create_engine( + DATABASE_URL, + connect_args={"check_same_thread": False} +) + + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def create_db_and_tables(): + """Creates all database tables defined in the Base.""" + Base.metadata.create_all(bind=engine) + print("Database tables created or already exist.") + +class KnowledgeBase(Base): + __tablename__ = 'kb' + id = Column(Integer, primary_key=True, index=True) + name = Column(String, index=True) + description = Column(Text) + created_at = Column(DateTime, default=datetime.utcnow) + embedding_model = Column(String, default='') + top_k = Column(Integer, default=5) + + +class File(Base): + __tablename__ = 'file' + id = Column(Integer, primary_key=True, index=True) + kb_id = Column(Integer, nullable=True) + file_name = Column(String) + path = Column(String) + created_at = Column(DateTime, default=datetime.utcnow) + file_type = Column(String) + status = Column(Integer, default=0) + + +class Chunk(Base): + __tablename__ = 'chunks' + id = Column(Integer, primary_key=True, index=True) + file_id = Column(Integer, nullable=True) + + text = Column(Text) + + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, nullable=True) + embedding = Column(LargeBinary) diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py index 6ded737a..9675371b 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/RAG_Manager.py @@ -1,38 +1,42 @@ -# RAG_Manager class (main class, adjust imports as needed) -from __future__ import annotations # For type hinting in Python 3.7+ +# rag_manager.py +from __future__ import annotations import logging import os import asyncio +import json +import uuid from pkg.rag.knowledge.services.parser import FileParser from pkg.rag.knowledge.services.chunker import Chunker from pkg.rag.knowledge.services.embedder import Embedder from pkg.rag.knowledge.services.retriever import Retriever -from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly +from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager -from pkg.core import app # Adjust the import path as needed - +from pkg.core import app class RAG_Manager: - ap: app.Application - def __init__(self, ap: app.Application,logger: logging.Logger = None): + def __init__(self, ap: app.Application, logger: logging.Logger = None): self.ap = ap self.logger = logger or logging.getLogger(__name__) self.embedding_model_type = None self.embedding_model_name = None self.chroma_manager = None - self.parser = None - self.chunker = None + self.parser = FileParser() + self.chunker = Chunker() self.embedder = None self.retriever = None async def initialize_rag_system(self): + """Initializes the RAG system by creating database tables.""" await asyncio.to_thread(create_db_and_tables) - async def create_specific_model(self, embedding_model_type: str, - embedding_model_name: str): + async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str): + """ + Creates and configures the specific embedding model and ChromaDB manager. + This must be called before performing embedding or retrieval operations. + """ self.embedding_model_type = embedding_model_type self.embedding_model_name = embedding_model_name @@ -47,52 +51,38 @@ class RAG_Manager: raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.") self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}") - - self.parser = FileParser() - self.chunker = Chunker() - # Pass chroma_manager to Embedder and Retriever - self.embedder = Embedder( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name, - chroma_manager=self.chroma_manager # Inject dependency - ) - self.retriever = Retriever( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name, - chroma_manager=self.chroma_manager # Inject dependency - ) - + self.embedder = Embedder(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager) + self.retriever = Retriever(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager) async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5): """ - Creates a new knowledge base with the given name and description. - If a knowledge base with the same name already exists, it returns that one. + Creates a new knowledge base if it doesn't already exist. """ try: - def _get_kb_sync(name): + if not self.embedding_model_type or not kb_name: + raise ValueError("Embedding model type and knowledge base name must be set before creating a knowledge base.") + def _create_kb_sync(): session = SessionLocal() try: - return session.query(KnowledgeBase).filter_by(name=name).first() - finally: - session.close() - - kb = await asyncio.to_thread(_get_kb_sync, kb_name) - - if not kb: - def _add_kb_sync(): - session = SessionLocal() - try: - new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k) + kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() + if not kb: + id = uuid.uuid4().int + new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k,id=id) session.add(new_kb) session.commit() session.refresh(new_kb) - return new_kb - finally: - session.close() - kb = await asyncio.to_thread(_add_kb_sync) - except Exception as e: - self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) - raise + self.logger.info(f"Knowledge Base '{kb_name}' created.") + return new_kb.id + else: + self.logger.info(f"Knowledge Base '{kb_name}' already exists.") + except Exception as e: + session.rollback() + self.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True) + raise + finally: + session.close() + + return await asyncio.to_thread(_create_kb_sync) except Exception as e: self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) raise @@ -108,116 +98,124 @@ class RAG_Manager: return session.query(KnowledgeBase).all() finally: session.close() - - kbs = await asyncio.to_thread(_get_all_kbs_sync) - return kbs + return await asyncio.to_thread(_get_all_kbs_sync) except Exception as e: self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True) return [] - + async def get_knowledge_base_by_id(self, kb_id: int): """ - Retrieves a knowledge base by its ID. + Retrieves a specific knowledge base by its ID. """ try: - def _get_kb_sync(kb_id): + def _get_kb_sync(kb_id_param): session = SessionLocal() try: - return session.query(KnowledgeBase).filter_by(id=kb_id).first() + return session.query(KnowledgeBase).filter_by(id=kb_id_param).first() finally: session.close() - - kb = await asyncio.to_thread(_get_kb_sync, kb_id) - return kb + return await asyncio.to_thread(_get_kb_sync, kb_id) except Exception as e: self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True) return None - + async def get_files_by_knowledge_base(self, kb_id: int): + """ + Retrieves files associated with a specific knowledge base by querying the File table directly. + """ try: - def _get_files_sync(kb_id): + def _get_files_sync(kb_id_param): session = SessionLocal() try: - return session.query(File).filter_by(kb_id=kb_id).all() + return session.query(File).filter_by(kb_id=kb_id_param).all() finally: session.close() - - files = await asyncio.to_thread(_get_files_sync, kb_id) - return files + return await asyncio.to_thread(_get_files_sync, kb_id) except Exception as e: self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True) return [] + async def get_all_files(self): + """ + Retrieves all files stored in the database, regardless of their association + with any specific knowledge base. + """ + try: + def _get_all_files_sync(): + session = SessionLocal() + try: + return session.query(File).all() + finally: + session.close() + return await asyncio.to_thread(_get_all_files_sync) + except Exception as e: + self.logger.error(f"Error retrieving all files: {str(e)}", exc_info=True) + return [] async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"): + """ + Parses, chunks, embeds, and stores data from a given file into the RAG system. + Associates the file with a knowledge base using kb_id in the File table. + """ self.logger.info(f"Starting data storage process for file: {file_path}") + session = SessionLocal() + file_obj = None + try: - def _get_kb_sync(name): - session = SessionLocal() - try: - return session.query(KnowledgeBase).filter_by(name=name).first() - finally: - session.close() - - kb = await asyncio.to_thread(_get_kb_sync, kb_name) - + # 1. 确保知识库存在或创建它 + kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: - self.logger.info(f"Knowledge Base '{kb_name}' not found. Creating a new one.") - def _add_kb_sync(): - session = SessionLocal() - try: - new_kb = KnowledgeBase(name=kb_name, description=kb_description) - session.add(new_kb) - session.commit() - session.refresh(new_kb) - return new_kb - finally: - session.close() - kb = await asyncio.to_thread(_add_kb_sync) - self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})") + kb = KnowledgeBase(name=kb_name, description=kb_description) + session.add(kb) + session.commit() + session.refresh(kb) + self.logger.info(f"Knowledge Base '{kb_name}' created during store_data.") + else: + self.logger.info(f"Knowledge Base '{kb_name}' already exists.") - def _add_file_sync(kb_id, file_name, path, file_type): - session = SessionLocal() - try: - file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type) - session.add(file) - session.commit() - session.refresh(file) - return file - finally: - session.close() - - file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type) - self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})") - - text = await self.parser.parse(file_path) - if not text: - self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.") - # You might want to delete the file_obj from the DB here if it's empty. - session = SessionLocal() - try: - session.delete(file_obj) - session.commit() - except Exception as del_e: - self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}") - finally: - session.close() + # 2. 添加文件记录到数据库,并直接关联 kb_id + file_name = os.path.basename(file_path) + existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first() + if existing_file: + self.logger.warning(f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage.") return + file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type) + session.add(file_obj) + session.commit() + session.refresh(file_obj) + self.logger.info(f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}") + + # 3. 解析文件内容 + text = await self.parser.parse(file_path) + if not text: + self.logger.warning(f"No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.") + session.delete(file_obj) + session.commit() # 提交删除操作 + return + + # 4. 分块并嵌入/存储块 chunks_texts = await self.chunker.chunk(text) - self.logger.info(f"Chunked into {len(chunks_texts)} pieces.") - - # embed_and_store now handles both DB chunk saving and Chroma embedding + self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) - self.logger.info(f"Data storage process completed for file: {file_path}") except Exception as e: + session.rollback() self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True) - # Consider cleaning up partially stored data if an error occurs. - return + if file_obj and file_obj.id: + try: + await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id) + except Exception as chroma_e: + self.logger.warning(f"Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}") + raise + finally: + session.close() async def retrieve_data(self, query: str): + """ + Retrieves relevant data chunks based on a given query using the configured retriever. + """ self.logger.info(f"Starting data retrieval process for query: '{query}'") try: retrieved_chunks = await self.retriever.retrieve(query) @@ -229,60 +227,140 @@ class RAG_Manager: async def delete_data_by_file_id(self, file_id: int): """ - Deletes data associated with a specific file_id from both the relational DB and Chroma. + Deletes all data associated with a specific file ID, including its chunks and vectors, + and the file record itself. """ self.logger.info(f"Starting data deletion process for file_id: {file_id}") session = SessionLocal() try: - # 1. Delete from Chroma + # 1. 从 ChromaDB 删除 embeddings await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) + self.logger.info(f"Deleted embeddings from ChromaDB for file_id: {file_id}") - # 2. Delete chunks from relational DB + # 2. 删除与文件关联的 chunks 记录 chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() for chunk in chunks_to_delete: session.delete(chunk) - self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.") + self.logger.info(f"Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}") - # 3. Delete file entry from relational DB + # 3. 删除文件记录本身 file_to_delete = session.query(File).filter_by(id=file_id).first() if file_to_delete: session.delete(file_to_delete) - self.logger.info(f"Deleted file entry {file_id} from relational DB.") + self.logger.info(f"Deleted file record for file_id: {file_id}") else: - self.logger.warning(f"File entry {file_id} not found in relational DB.") + self.logger.warning(f"File with ID {file_id} not found in database. Skipping deletion of file record.") session.commit() - self.logger.info(f"Data deletion completed for file_id: {file_id}.") + self.logger.info(f"Successfully completed data deletion for file_id: {file_id}") except Exception as e: session.rollback() self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True) + raise finally: session.close() async def delete_kb_by_id(self, kb_id: int): """ - Deletes a knowledge base and all associated files and chunks. + Deletes a knowledge base and all associated files, chunks, and vectors. + This involves querying for associated files and then deleting them. """ self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}") - session = SessionLocal() + session = SessionLocal() # 使用新的会话来获取 KB 和关联文件 + try: - # 1. Get the knowledge base - kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() - if not kb: + kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if not kb_to_delete: self.logger.warning(f"Knowledge Base with ID {kb_id} not found.") return - # 2. Delete all files associated with this knowledge base - files_to_delete = session.query(File).filter_by(kb_id=kb.id).all() - for file in files_to_delete: - await self.delete_data_by_file_id(file.id) + # 获取所有关联的文件,通过 File 表的 kb_id 字段查询 + files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() + + # 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话 + session.close() - # 3. Delete the knowledge base itself - session.delete(kb) + # 遍历删除每个关联文件及其数据 + for file_obj in files_to_delete: + try: + await self.delete_data_by_file_id(file_obj.id) + except Exception as file_del_e: + self.logger.error(f"Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}") + # 记录错误但继续,尝试删除其他文件 + + # 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身 + session = SessionLocal() + try: + # 重新查询,确保对象是当前会话的一部分 + kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if kb_final_delete: + session.delete(kb_final_delete) + session.commit() + self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}") + else: + self.logger.warning(f"Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.") + except Exception as kb_del_e: + session.rollback() + self.logger.error(f"Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}", exc_info=True) + raise + finally: + session.close() + + except Exception as e: + # 如果在最初获取 KB 或文件列表时出错 + if session.is_active: + session.rollback() + self.logger.error(f"Error during overall knowledge base deletion for ID {kb_id}: {str(e)}", exc_info=True) + raise + finally: + if session.is_active: + session.close() + + + + async def get_file_content_by_file_id(self, file_id: str) -> str: + + file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) + + _, ext = os.path.splitext(file_id.lower()) + ext = ext.lstrip('.') + + try: + text = file_bytes.decode("utf-8") + except UnicodeDecodeError: + return "[非文本文件或编码无法识别]" + + if ext in ["txt", "md", "csv", "log", "py", "html"]: + return text + else: + return f"[未知类型: .{ext}]" + + async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None: + """ + Associates a file with a knowledge base by updating the kb_id in the File table. + """ + self.logger.info(f"Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}") + session = SessionLocal() + try: + # 查询知识库是否存在 + kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first() + if not kb: + self.logger.error(f"Knowledge Base with UUID {knowledge_base_uuid} not found.") + return + + # 更新文件的 kb_id + file_to_update = session.query(File).filter_by(id=file_id).first() + if not file_to_update: + self.logger.error(f"File with ID {file_id} not found.") + return + + file_to_update.kb_id = kb.id session.commit() - self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}") + self.logger.info(f"Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}") except Exception as e: session.rollback() - self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True) + self.logger.error(f"Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}", exc_info=True) finally: session.close() + + diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py index 35a52453..bc5caa10 100644 --- a/pkg/rag/knowledge/services/database.py +++ b/pkg/rag/knowledge/services/database.py @@ -1,64 +1,23 @@ -from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary -from sqlalchemy.orm import declarative_base, sessionmaker, relationship -from datetime import datetime +# 全部迁移过去 -Base = declarative_base() +from pkg.entity.persistence.rag import ( + create_db_and_tables, + SessionLocal, + Base, + engine, + KnowledgeBase, + File, + Chunk, + Vector, +) - -class KnowledgeBase(Base): - __tablename__ = 'kb' - id = Column(Integer, primary_key=True, index=True) - name = Column(String, index=True) - description = Column(Text) - created_at = Column(DateTime, default=datetime.utcnow) - embedding_model = Column(String, default='') # 默认嵌入模型 - top_k = Column(Integer, default=5) # 默认返回的top_k数量 - files = relationship('File', back_populates='knowledge_base') - - -class File(Base): - __tablename__ = 'file' - id = Column(Integer, primary_key=True, index=True) - kb_id = Column(Integer, ForeignKey('kb.id')) - file_name = Column(String) - path = Column(String) - created_at = Column(DateTime, default=datetime.utcnow) - file_type = Column(String) - status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误 - knowledge_base = relationship('KnowledgeBase', back_populates='files') - chunks = relationship('Chunk', back_populates='file') - - -class Chunk(Base): - __tablename__ = 'chunks' - id = Column(Integer, primary_key=True, index=True) - file_id = Column(Integer, ForeignKey('file.id')) - text = Column(Text) - - file = relationship('File', back_populates='chunks') - vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one - - -class Vector(Base): - __tablename__ = 'vectors' - id = Column(Integer, primary_key=True, index=True) - chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) - embedding = Column(LargeBinary) # Store embeddings as binary - - chunk = relationship('Chunk', back_populates='vector') - - -# 数据库连接 -DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL -engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {}) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -# 创建所有表 (可以在应用启动时执行一次) -def create_db_and_tables(): - Base.metadata.create_all(bind=engine) - print('Database tables created/checked.') - - -# 定义嵌入维度(请根据你实际使用的模型调整) -EMBEDDING_DIM = 1024 +__all__ = [ + "create_db_and_tables", + "SessionLocal", + "Base", + "engine", + "KnowledgeBase", + "File", + "Chunk", + "Vector", +] From 75c3ddde19b54418af5c1687bfe71810c3773c6e Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Thu, 10 Jul 2025 16:45:59 +0800 Subject: [PATCH 21/60] perf: definitions --- .../controller/groups/knowledge/__init__.py | 0 .../{knowledge_base.py => knowledge/base.py} | 18 +- pkg/api/http/controller/main.py | 2 + pkg/core/app.py | 7 +- pkg/core/stages/build_app.py | 4 +- pkg/rag/knowledge/{RAG_Manager.py => mgr.py} | 160 +++++++++++------- 6 files changed, 112 insertions(+), 79 deletions(-) create mode 100644 pkg/api/http/controller/groups/knowledge/__init__.py rename pkg/api/http/controller/groups/{knowledge_base.py => knowledge/base.py} (90%) rename pkg/rag/knowledge/{RAG_Manager.py => mgr.py} (67%) diff --git a/pkg/api/http/controller/groups/knowledge/__init__.py b/pkg/api/http/controller/groups/knowledge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge/base.py similarity index 90% rename from pkg/api/http/controller/groups/knowledge_base.py rename to pkg/api/http/controller/groups/knowledge/base.py index ce391042..cf5bb44e 100644 --- a/pkg/api/http/controller/groups/knowledge_base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -1,13 +1,9 @@ import quart -from .. import group -import os # 导入 os 用于文件操作 +from ... import group + @group.group_class('knowledge_base', '/api/v1/knowledge/bases') class KnowledgeBaseRouterGroup(group.RouterGroup): - # 定义成功方法 - def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response: - return quart.jsonify({'code': code, 'data': data or {}, 'msg': msg}) - async def initialize(self) -> None: @self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases') async def handle_knowledge_bases() -> str: @@ -51,7 +47,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid)) return self.success(code=0, msg='ok') - @self.route('//files', methods=['GET'], endpoint='get_knowledge_base_files') async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid)) @@ -68,14 +63,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): msg='ok', ) - @self.route('//files/', methods=['DELETE'], endpoint='delete_specific_file_in_kb') async def delete_specific_file_in_kb(file_id: str) -> str: await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id)) return self.success(code=0, msg='ok') - + @self.route('//files', methods=['POST'], endpoint='relate_file_with_kb') - async def relate_file_id_with_kb(knowledge_base_uuid:str,file_id: str) -> str: + async def relate_file_id_with_kb(knowledge_base_uuid: str, file_id: str) -> str: if 'file' not in quart.request.files: return self.http_status(400, -1, 'No file part in the request') @@ -83,7 +77,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): file_id = json_data.get('file_id') if not file_id: return self.http_status(400, -1, 'File ID is required') - + # 调用服务层方法将文件与知识库关联 await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id)) - return self.success(code=0, data={}, msg='ok') \ No newline at end of file + return self.success(code=0, data={}, msg='ok') diff --git a/pkg/api/http/controller/main.py b/pkg/api/http/controller/main.py index eb434d88..4eec4e1d 100644 --- a/pkg/api/http/controller/main.py +++ b/pkg/api/http/controller/main.py @@ -14,11 +14,13 @@ from . import group from .groups import provider as groups_provider from .groups import platform as groups_platform from .groups import pipelines as groups_pipelines +from .groups import knowledge as groups_knowledge importutil.import_modules_in_pkg(groups) importutil.import_modules_in_pkg(groups_provider) importutil.import_modules_in_pkg(groups_platform) importutil.import_modules_in_pkg(groups_pipelines) +importutil.import_modules_in_pkg(groups_knowledge) class HTTPController: diff --git a/pkg/core/app.py b/pkg/core/app.py index 2e3c9500..11d25826 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -27,7 +27,7 @@ from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities -from pkg.rag.knowledge.RAG_Manager import RAG_Manager +from ..rag.knowledge import mgr as rag_mgr class Application: @@ -48,7 +48,6 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None - # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None @@ -101,7 +100,6 @@ class Application: storage_mgr: storagemgr.StorageMgr = None - # ========= HTTP Services ========= user_service: user_service.UserService = None @@ -114,8 +112,7 @@ class Application: bot_service: bot_service.BotService = None - knowledge_base_service: RAG_Manager = None - + knowledge_base_service: rag_mgr.RAGManager = None def __init__(self): pass diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index bb86a6d3..ac76c331 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -9,7 +9,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.tools import toolmgr as llm_tool_mgr -from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr +from ...rag.knowledge import mgr as rag_mgr from ...platform import botmgr as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -102,7 +102,7 @@ class BuildAppStage(stage.BootingStage): embedding_models_service_inst = model_service.EmbeddingModelsService(ap) ap.embedding_models_service = embedding_models_service_inst - knowledge_base_service_inst = knowledge_base_mgr(ap) + knowledge_base_service_inst = rag_mgr.RAGManager(ap) await knowledge_base_service_inst.initialize_rag_system() ap.knowledge_base_service = knowledge_base_service_inst diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/mgr.py similarity index 67% rename from pkg/rag/knowledge/RAG_Manager.py rename to pkg/rag/knowledge/mgr.py index 9675371b..7d1787e0 100644 --- a/pkg/rag/knowledge/RAG_Manager.py +++ b/pkg/rag/knowledge/mgr.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging import os import asyncio -import json import uuid from pkg.rag.knowledge.services.parser import FileParser from pkg.rag.knowledge.services.chunker import Chunker @@ -14,7 +13,8 @@ from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager from pkg.core import app -class RAG_Manager: + +class RAGManager: ap: app.Application def __init__(self, ap: app.Application, logger: logging.Logger = None): @@ -42,32 +42,54 @@ class RAG_Manager: try: model = EmbeddingModelFactory.create_model( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name + model_type=self.embedding_model_type, model_name_key=self.embedding_model_name + ) + self.logger.info( + f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}" ) - self.logger.info(f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}") except Exception as e: - self.logger.critical(f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}") - raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.") + self.logger.critical( + f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}" + ) + raise RuntimeError('Failed to initialize RAG_Manager due to embedding model issues.') - self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}") - self.embedder = Embedder(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager) - self.retriever = Retriever(model_type=self.embedding_model_type, model_name_key=self.embedding_model_name, chroma_manager=self.chroma_manager) + self.chroma_manager = ChromaIndexManager( + collection_name=f'rag_collection_{self.embedding_model_name.replace("-", "_")}' + ) + self.embedder = Embedder( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name, + chroma_manager=self.chroma_manager, + ) + self.retriever = Retriever( + model_type=self.embedding_model_type, + model_name_key=self.embedding_model_name, + chroma_manager=self.chroma_manager, + ) - async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5): + async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5): """ Creates a new knowledge base if it doesn't already exist. """ try: if not self.embedding_model_type or not kb_name: - raise ValueError("Embedding model type and knowledge base name must be set before creating a knowledge base.") + raise ValueError( + 'Embedding model type and knowledge base name must be set before creating a knowledge base.' + ) + def _create_kb_sync(): session = SessionLocal() try: kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: id = uuid.uuid4().int - new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k,id=id) + new_kb = KnowledgeBase( + name=kb_name, + description=kb_description, + embedding_model=embedding_model, + top_k=top_k, + id=id, + ) session.add(new_kb) session.commit() session.refresh(new_kb) @@ -80,7 +102,7 @@ class RAG_Manager: self.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True) raise finally: - session.close() + session.close() return await asyncio.to_thread(_create_kb_sync) except Exception as e: @@ -92,15 +114,17 @@ class RAG_Manager: Retrieves all knowledge bases from the database. """ try: + def _get_all_kbs_sync(): session = SessionLocal() try: return session.query(KnowledgeBase).all() finally: session.close() + return await asyncio.to_thread(_get_all_kbs_sync) except Exception as e: - self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True) + self.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True) return [] async def get_knowledge_base_by_id(self, kb_id: int): @@ -108,15 +132,17 @@ class RAG_Manager: Retrieves a specific knowledge base by its ID. """ try: + def _get_kb_sync(kb_id_param): session = SessionLocal() try: return session.query(KnowledgeBase).filter_by(id=kb_id_param).first() finally: session.close() + return await asyncio.to_thread(_get_kb_sync, kb_id) except Exception as e: - self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True) + self.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True) return None async def get_files_by_knowledge_base(self, kb_id: int): @@ -124,15 +150,17 @@ class RAG_Manager: Retrieves files associated with a specific knowledge base by querying the File table directly. """ try: + def _get_files_sync(kb_id_param): session = SessionLocal() try: return session.query(File).filter_by(kb_id=kb_id_param).all() finally: session.close() + return await asyncio.to_thread(_get_files_sync, kb_id) except Exception as e: - self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True) + self.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True) return [] async def get_all_files(self): @@ -141,23 +169,27 @@ class RAG_Manager: with any specific knowledge base. """ try: + def _get_all_files_sync(): session = SessionLocal() try: return session.query(File).all() finally: session.close() + return await asyncio.to_thread(_get_all_files_sync) except Exception as e: - self.logger.error(f"Error retrieving all files: {str(e)}", exc_info=True) + self.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True) return [] - async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"): + async def store_data( + self, file_path: str, kb_name: str, file_type: str, kb_description: str = 'Default knowledge base' + ): """ Parses, chunks, embeds, and stores data from a given file into the RAG system. Associates the file with a knowledge base using kb_id in the File table. """ - self.logger.info(f"Starting data storage process for file: {file_path}") + self.logger.info(f'Starting data storage process for file: {file_path}') session = SessionLocal() file_obj = None @@ -177,37 +209,43 @@ class RAG_Manager: file_name = os.path.basename(file_path) existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first() if existing_file: - self.logger.warning(f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage.") + self.logger.warning( + f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage." + ) return file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type) session.add(file_obj) session.commit() session.refresh(file_obj) - self.logger.info(f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}") + self.logger.info( + f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}" + ) # 3. 解析文件内容 text = await self.parser.parse(file_path) if not text: - self.logger.warning(f"No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.") + self.logger.warning(f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.') session.delete(file_obj) - session.commit() # 提交删除操作 + session.commit() # 提交删除操作 return # 4. 分块并嵌入/存储块 chunks_texts = await self.chunker.chunk(text) self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) - self.logger.info(f"Data storage process completed for file: {file_path}") + self.logger.info(f'Data storage process completed for file: {file_path}') except Exception as e: session.rollback() - self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True) + self.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True) if file_obj and file_obj.id: try: await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id) except Exception as chroma_e: - self.logger.warning(f"Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}") + self.logger.warning( + f'Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}' + ) raise finally: session.close() @@ -219,7 +257,7 @@ class RAG_Manager: self.logger.info(f"Starting data retrieval process for query: '{query}'") try: retrieved_chunks = await self.retriever.retrieve(query) - self.logger.info(f"Successfully retrieved {len(retrieved_chunks)} chunks for query.") + self.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.') return retrieved_chunks except Exception as e: self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) @@ -230,32 +268,32 @@ class RAG_Manager: Deletes all data associated with a specific file ID, including its chunks and vectors, and the file record itself. """ - self.logger.info(f"Starting data deletion process for file_id: {file_id}") + self.logger.info(f'Starting data deletion process for file_id: {file_id}') session = SessionLocal() try: # 1. 从 ChromaDB 删除 embeddings await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) - self.logger.info(f"Deleted embeddings from ChromaDB for file_id: {file_id}") + self.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') # 2. 删除与文件关联的 chunks 记录 chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() for chunk in chunks_to_delete: session.delete(chunk) - self.logger.info(f"Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}") + self.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}') # 3. 删除文件记录本身 file_to_delete = session.query(File).filter_by(id=file_id).first() if file_to_delete: session.delete(file_to_delete) - self.logger.info(f"Deleted file record for file_id: {file_id}") + self.logger.info(f'Deleted file record for file_id: {file_id}') else: - self.logger.warning(f"File with ID {file_id} not found in database. Skipping deletion of file record.") + self.logger.warning(f'File with ID {file_id} not found in database. Skipping deletion of file record.') session.commit() - self.logger.info(f"Successfully completed data deletion for file_id: {file_id}") + self.logger.info(f'Successfully completed data deletion for file_id: {file_id}') except Exception as e: session.rollback() - self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True) + self.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True) raise finally: session.close() @@ -265,27 +303,27 @@ class RAG_Manager: Deletes a knowledge base and all associated files, chunks, and vectors. This involves querying for associated files and then deleting them. """ - self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}") - session = SessionLocal() # 使用新的会话来获取 KB 和关联文件 + self.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}') + session = SessionLocal() # 使用新的会话来获取 KB 和关联文件 try: kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() if not kb_to_delete: - self.logger.warning(f"Knowledge Base with ID {kb_id} not found.") + self.logger.warning(f'Knowledge Base with ID {kb_id} not found.') return # 获取所有关联的文件,通过 File 表的 kb_id 字段查询 files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() - + # 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话 - session.close() + session.close() # 遍历删除每个关联文件及其数据 for file_obj in files_to_delete: try: await self.delete_data_by_file_id(file_obj.id) except Exception as file_del_e: - self.logger.error(f"Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}") + self.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}') # 记录错误但继续,尝试删除其他文件 # 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身 @@ -296,12 +334,14 @@ class RAG_Manager: if kb_final_delete: session.delete(kb_final_delete) session.commit() - self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}") + self.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}') else: - self.logger.warning(f"Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.") + self.logger.warning( + f'Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.' + ) except Exception as kb_del_e: session.rollback() - self.logger.error(f"Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}", exc_info=True) + self.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True) raise finally: session.close() @@ -310,57 +350,57 @@ class RAG_Manager: # 如果在最初获取 KB 或文件列表时出错 if session.is_active: session.rollback() - self.logger.error(f"Error during overall knowledge base deletion for ID {kb_id}: {str(e)}", exc_info=True) + self.logger.error(f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True) raise finally: if session.is_active: session.close() - - async def get_file_content_by_file_id(self, file_id: str) -> str: - file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) _, ext = os.path.splitext(file_id.lower()) ext = ext.lstrip('.') try: - text = file_bytes.decode("utf-8") + text = file_bytes.decode('utf-8') except UnicodeDecodeError: - return "[非文本文件或编码无法识别]" + return '[非文本文件或编码无法识别]' - if ext in ["txt", "md", "csv", "log", "py", "html"]: + if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']: return text else: - return f"[未知类型: .{ext}]" - + return f'[未知类型: .{ext}]' + async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None: """ Associates a file with a knowledge base by updating the kb_id in the File table. """ - self.logger.info(f"Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}") + self.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}') session = SessionLocal() try: # 查询知识库是否存在 kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first() if not kb: - self.logger.error(f"Knowledge Base with UUID {knowledge_base_uuid} not found.") + self.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.') return # 更新文件的 kb_id file_to_update = session.query(File).filter_by(id=file_id).first() if not file_to_update: - self.logger.error(f"File with ID {file_id} not found.") + self.logger.error(f'File with ID {file_id} not found.') return file_to_update.kb_id = kb.id session.commit() - self.logger.info(f"Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}") + self.logger.info( + f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}' + ) except Exception as e: session.rollback() - self.logger.error(f"Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}", exc_info=True) + self.logger.error( + f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}', + exc_info=True, + ) finally: session.close() - - From 367d04d0f073f1f2b052976e3719944fd528178c Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 11:28:43 +0800 Subject: [PATCH 22/60] fix: success method bad params --- .../http/controller/groups/knowledge/base.py | 56 ++++++++++++------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index cf5bb44e..bfbbbe10 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -17,16 +17,20 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): } for kb in knowledge_bases ] - return self.success(code=0, data={'bases': bases_list}, msg='ok') + return self.success(data={'bases': bases_list}) # POST: create a new knowledge base json_data = await quart.request.json knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( json_data.get('name'), json_data.get('description') ) - return self.success(code=0, data={'uuid': knowledge_base_uuid}, msg='ok') + return self.success(data={'uuid': knowledge_base_uuid}) - @self.route('/', methods=['GET', 'DELETE'], endpoint='handle_specific_knowledge_base') + @self.route( + '/', + methods=['GET', 'DELETE'], + endpoint='handle_specific_knowledge_base', + ) async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(int(knowledge_base_uuid)) @@ -35,40 +39,50 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): return self.http_status(404, -1, 'knowledge base not found') return self.success( - code=0, data={ 'name': knowledge_base.name, 'description': knowledge_base.description, 'uuid': knowledge_base.id, }, - msg='ok', ) elif quart.request.method == 'DELETE': await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid)) - return self.success(code=0, msg='ok') + return self.success({}) - @self.route('//files', methods=['GET'], endpoint='get_knowledge_base_files') + @self.route( + '//files', + methods=['GET'], + endpoint='get_knowledge_base_files', + ) async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid)) return self.success( - code=0, - data=[ - { - 'id': file.id, - 'file_name': file.file_name, - 'status': file.status, - } - for file in files - ], - msg='ok', + data={ + 'files': [ + { + 'id': file.id, + 'file_name': file.file_name, + 'status': file.status, + } + for file in files + ], + } ) - @self.route('//files/', methods=['DELETE'], endpoint='delete_specific_file_in_kb') + @self.route( + '//files/', + methods=['DELETE'], + endpoint='delete_specific_file_in_kb', + ) async def delete_specific_file_in_kb(file_id: str) -> str: await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id)) - return self.success(code=0, msg='ok') + return self.success({}) - @self.route('//files', methods=['POST'], endpoint='relate_file_with_kb') + @self.route( + '//files', + methods=['POST'], + endpoint='relate_file_with_kb', + ) async def relate_file_id_with_kb(knowledge_base_uuid: str, file_id: str) -> str: if 'file' not in quart.request.files: return self.http_status(400, -1, 'No file part in the request') @@ -80,4 +94,4 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): # 调用服务层方法将文件与知识库关联 await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id)) - return self.success(code=0, data={}, msg='ok') + return self.success({}) From 9ba1ad5bd38e48a3315f165ce88e73c3399b56fc Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 16:38:08 +0800 Subject: [PATCH 23/60] fix: bugs --- .../http/controller/groups/knowledge/base.py | 14 +++--- pkg/entity/persistence/rag.py | 12 ++--- pkg/rag/knowledge/mgr.py | 47 +------------------ 3 files changed, 15 insertions(+), 58 deletions(-) diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index bfbbbe10..b5a48d29 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -14,17 +14,19 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): 'uuid': kb.id, 'name': kb.name, 'description': kb.description, + 'embedding_model_uuid': kb.embedding_model_uuid, + 'top_k': kb.top_k, } for kb in knowledge_bases ] return self.success(data={'bases': bases_list}) - # POST: create a new knowledge base - json_data = await quart.request.json - knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( - json_data.get('name'), json_data.get('description') - ) - return self.success(data={'uuid': knowledge_base_uuid}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( + json_data.get('name'), json_data.get('description'), json_data.get('embedding_model_uuid') + ) + return self.success(data={'uuid': knowledge_base_uuid}) @self.route( '/', diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py index 175720f1..1657196a 100644 --- a/pkg/entity/persistence/rag.py +++ b/pkg/entity/persistence/rag.py @@ -5,13 +5,10 @@ import os Base = declarative_base() -DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./rag_knowledge.db") +DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db') -engine = create_engine( - DATABASE_URL, - connect_args={"check_same_thread": False} -) +engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False}) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -20,7 +17,8 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) def create_db_and_tables(): """Creates all database tables defined in the Base.""" Base.metadata.create_all(bind=engine) - print("Database tables created or already exist.") + print('Database tables created or already exist.') + class KnowledgeBase(Base): __tablename__ = 'kb' @@ -28,7 +26,7 @@ class KnowledgeBase(Base): name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) - embedding_model = Column(String, default='') + embedding_model_uuid = Column(String, default='') top_k = Column(Integer, default=5) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 7d1787e0..5d4eece9 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -6,11 +6,7 @@ import asyncio import uuid from pkg.rag.knowledge.services.parser import FileParser from pkg.rag.knowledge.services.chunker import Chunker -from pkg.rag.knowledge.services.embedder import Embedder -from pkg.rag.knowledge.services.retriever import Retriever from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk -from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory -from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager from pkg.core import app @@ -20,8 +16,6 @@ class RAGManager: def __init__(self, ap: app.Application, logger: logging.Logger = None): self.ap = ap self.logger = logger or logging.getLogger(__name__) - self.embedding_model_type = None - self.embedding_model_name = None self.chroma_manager = None self.parser = FileParser() self.chunker = Chunker() @@ -32,50 +26,13 @@ class RAGManager: """Initializes the RAG system by creating database tables.""" await asyncio.to_thread(create_db_and_tables) - async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str): - """ - Creates and configures the specific embedding model and ChromaDB manager. - This must be called before performing embedding or retrieval operations. - """ - self.embedding_model_type = embedding_model_type - self.embedding_model_name = embedding_model_name - - try: - model = EmbeddingModelFactory.create_model( - model_type=self.embedding_model_type, model_name_key=self.embedding_model_name - ) - self.logger.info( - f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}" - ) - except Exception as e: - self.logger.critical( - f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}" - ) - raise RuntimeError('Failed to initialize RAG_Manager due to embedding model issues.') - - self.chroma_manager = ChromaIndexManager( - collection_name=f'rag_collection_{self.embedding_model_name.replace("-", "_")}' - ) - self.embedder = Embedder( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name, - chroma_manager=self.chroma_manager, - ) - self.retriever = Retriever( - model_type=self.embedding_model_type, - model_name_key=self.embedding_model_name, - chroma_manager=self.chroma_manager, - ) - async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5): """ Creates a new knowledge base if it doesn't already exist. """ try: - if not self.embedding_model_type or not kb_name: - raise ValueError( - 'Embedding model type and knowledge base name must be set before creating a knowledge base.' - ) + if not kb_name: + raise ValueError('Knowledge base name must be set while creating.') def _create_kb_sync(): session = SessionLocal() From 7d5503dab201a44112b120e2bde6b47e3117f610 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 16:49:55 +0800 Subject: [PATCH 24/60] fix: bug --- pkg/rag/knowledge/mgr.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 5d4eece9..4da10a09 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -26,7 +26,9 @@ class RAGManager: """Initializes the RAG system by creating database tables.""" await asyncio.to_thread(create_db_and_tables) - async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5): + async def create_knowledge_base( + self, kb_name: str, kb_description: str, embedding_model_uuid: str = '', top_k: int = 5 + ): """ Creates a new knowledge base if it doesn't already exist. """ @@ -43,7 +45,7 @@ class RAGManager: new_kb = KnowledgeBase( name=kb_name, description=kb_description, - embedding_model=embedding_model, + embedding_model_uuid=embedding_model_uuid, top_k=top_k, id=id, ) From 815cdf8b4acb25e429cf2b15acb90b408340f4a0 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 17:22:43 +0800 Subject: [PATCH 25/60] feat: kb dialog action --- .../knowledge/components/kb-form/KBForm.tsx | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx index 9ae51656..0d4f0909 100644 --- a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx +++ b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx @@ -23,6 +23,7 @@ import { SelectTrigger, SelectValue, } from '@/components/ui/select'; +import { KnowledgeBase } from '@/app/infra/entities/api'; const getFormSchema = (t: (key: string) => string) => z.object({ @@ -81,11 +82,40 @@ export default function KBForm({ ); }; + const onSubmit = (data: z.infer) => { + console.log('data', data); + + if (initKbId) { + // update knowledge base + const updateKb: KnowledgeBase = { + name: data.name, + description: data.description, + embedding_model_uuid: data.embeddingModelUUID, + }; + } else { + // create knowledge base + const newKb: KnowledgeBase = { + name: data.name, + description: data.description, + embedding_model_uuid: data.embeddingModelUUID, + }; + httpClient + .createKnowledgeBase(newKb) + .then((res) => { + console.log('create knowledge base success', res); + onNewKbCreated(res.uuid); + }) + .catch((err) => { + console.error('create knowledge base failed', err); + }); + } + }; + return ( <>
From 14c161b73316e268e91b30f6705af9bec0652e6a Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Fri, 11 Jul 2025 18:14:03 +0800 Subject: [PATCH 26/60] fix: create knwoledge base issue --- pkg/entity/persistence/rag.py | 26 ++++++++++---------------- pkg/rag/knowledge/mgr.py | 34 +++++++++++++++++----------------- 2 files changed, 27 insertions(+), 33 deletions(-) diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py index 1657196a..95a78712 100644 --- a/pkg/entity/persistence/rag.py +++ b/pkg/entity/persistence/rag.py @@ -1,19 +1,17 @@ -from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, LargeBinary +from sqlalchemy import create_engine, Column, String, Text, DateTime, LargeBinary, Integer from sqlalchemy.orm import declarative_base, sessionmaker from datetime import datetime import os - Base = declarative_base() DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db') +print("Using database URL:", DATABASE_URL) engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False}) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - def create_db_and_tables(): """Creates all database tables defined in the Base.""" Base.metadata.create_all(bind=engine) @@ -22,35 +20,31 @@ def create_db_and_tables(): class KnowledgeBase(Base): __tablename__ = 'kb' - id = Column(Integer, primary_key=True, index=True) + id = Column(String, primary_key=True, index=True) name = Column(String, index=True) description = Column(Text) created_at = Column(DateTime, default=datetime.utcnow) embedding_model_uuid = Column(String, default='') top_k = Column(Integer, default=5) - class File(Base): __tablename__ = 'file' - id = Column(Integer, primary_key=True, index=True) - kb_id = Column(Integer, nullable=True) + id = Column(String, primary_key=True, index=True) + kb_id = Column(String, nullable=True) file_name = Column(String) path = Column(String) created_at = Column(DateTime, default=datetime.utcnow) file_type = Column(String) - status = Column(Integer, default=0) - + status = Column(String, default='0') class Chunk(Base): __tablename__ = 'chunks' - id = Column(Integer, primary_key=True, index=True) - file_id = Column(Integer, nullable=True) - + id = Column(String, primary_key=True, index=True) + file_id = Column(String, nullable=True) text = Column(Text) - class Vector(Base): __tablename__ = 'vectors' - id = Column(Integer, primary_key=True, index=True) - chunk_id = Column(Integer, nullable=True) + id = Column(String, primary_key=True, index=True) + chunk_id = Column(String, nullable=True) embedding = Column(LargeBinary) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 4da10a09..585a5075 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -41,7 +41,7 @@ class RAGManager: try: kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: - id = uuid.uuid4().int + id = str(uuid.uuid4()) new_kb = KnowledgeBase( name=kb_name, description=kb_description, @@ -86,7 +86,7 @@ class RAGManager: self.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True) return [] - async def get_knowledge_base_by_id(self, kb_id: int): + async def get_knowledge_base_by_id(self, kb_id: str): """ Retrieves a specific knowledge base by its ID. """ @@ -104,7 +104,7 @@ class RAGManager: self.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True) return None - async def get_files_by_knowledge_base(self, kb_id: int): + async def get_files_by_knowledge_base(self, kb_id: str): """ Retrieves files associated with a specific knowledge base by querying the File table directly. """ @@ -153,7 +153,7 @@ class RAGManager: file_obj = None try: - # 1. 确保知识库存在或创建它 + kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: kb = KnowledgeBase(name=kb_name, description=kb_description) @@ -164,7 +164,7 @@ class RAGManager: else: self.logger.info(f"Knowledge Base '{kb_name}' already exists.") - # 2. 添加文件记录到数据库,并直接关联 kb_id + file_name = os.path.basename(file_path) existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first() if existing_file: @@ -181,15 +181,15 @@ class RAGManager: f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}" ) - # 3. 解析文件内容 + text = await self.parser.parse(file_path) if not text: self.logger.warning(f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.') session.delete(file_obj) - session.commit() # 提交删除操作 + session.commit() return - # 4. 分块并嵌入/存储块 + chunks_texts = await self.chunker.chunk(text) self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) @@ -222,7 +222,7 @@ class RAGManager: self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) return [] - async def delete_data_by_file_id(self, file_id: int): + async def delete_data_by_file_id(self, file_id: str): """ Deletes all data associated with a specific file ID, including its chunks and vectors, and the file record itself. @@ -257,13 +257,13 @@ class RAGManager: finally: session.close() - async def delete_kb_by_id(self, kb_id: int): + async def delete_kb_by_id(self, kb_id: str): """ Deletes a knowledge base and all associated files, chunks, and vectors. This involves querying for associated files and then deleting them. """ self.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}') - session = SessionLocal() # 使用新的会话来获取 KB 和关联文件 + session = SessionLocal() try: kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() @@ -271,24 +271,24 @@ class RAGManager: self.logger.warning(f'Knowledge Base with ID {kb_id} not found.') return - # 获取所有关联的文件,通过 File 表的 kb_id 字段查询 + files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() - # 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话 + session.close() - # 遍历删除每个关联文件及其数据 + for file_obj in files_to_delete: try: await self.delete_data_by_file_id(file_obj.id) except Exception as file_del_e: self.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}') - # 记录错误但继续,尝试删除其他文件 + - # 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身 + session = SessionLocal() try: - # 重新查询,确保对象是当前会话的一部分 + kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() if kb_final_delete: session.delete(kb_final_delete) From bd9331ce62f8aea92dbe63757b8c38696464ff87 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 20:57:09 +0800 Subject: [PATCH 27/60] fix: kb get api format --- .../http/controller/groups/knowledge/base.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index b5a48d29..70cf2b0c 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -24,7 +24,9 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): elif quart.request.method == 'POST': json_data = await quart.request.json knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base( - json_data.get('name'), json_data.get('description'), json_data.get('embedding_model_uuid') + json_data.get('name'), + json_data.get('description'), + json_data.get('embedding_model_uuid'), ) return self.success(data={'uuid': knowledge_base_uuid}) @@ -35,20 +37,22 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): ) async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str: if quart.request.method == 'GET': - knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(int(knowledge_base_uuid)) + knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid) if knowledge_base is None: return self.http_status(404, -1, 'knowledge base not found') return self.success( data={ - 'name': knowledge_base.name, - 'description': knowledge_base.description, - 'uuid': knowledge_base.id, - }, + 'base': { + 'name': knowledge_base.name, + 'description': knowledge_base.description, + 'uuid': knowledge_base.id, + }, + } ) elif quart.request.method == 'DELETE': - await self.ap.knowledge_base_service.delete_kb_by_id(int(knowledge_base_uuid)) + await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid) return self.success({}) @self.route( @@ -57,7 +61,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): endpoint='get_knowledge_base_files', ) async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: - files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(int(knowledge_base_uuid)) + files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) return self.success( data={ 'files': [ @@ -77,7 +81,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): endpoint='delete_specific_file_in_kb', ) async def delete_specific_file_in_kb(file_id: str) -> str: - await self.ap.knowledge_base_service.delete_data_by_file_id(int(file_id)) + await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) return self.success({}) @self.route( @@ -95,5 +99,5 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): return self.http_status(400, -1, 'File ID is required') # 调用服务层方法将文件与知识库关联 - await self.ap.knowledge_base_service.relate_file_id_with_kb(int(knowledge_base_uuid), int(file_id)) + await self.ap.knowledge_base_service.relate_file_id_with_kb(knowledge_base_uuid, file_id) return self.success({}) From 2ed3b687904feaa48ee27faf6793a91486ca3b31 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 20:58:51 +0800 Subject: [PATCH 28/60] fix: kb get api not contains model uuid --- pkg/api/http/controller/groups/knowledge/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index 70cf2b0c..594fe7bf 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -48,6 +48,8 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): 'name': knowledge_base.name, 'description': knowledge_base.description, 'uuid': knowledge_base.id, + 'embedding_model_uuid': knowledge_base.embedding_model_uuid, + 'top_k': knowledge_base.top_k, }, } ) From a79a22a74d0ea19b04c23f0bc5f9a25657012f9d Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 21:30:47 +0800 Subject: [PATCH 29/60] fix: api bug --- .../http/controller/groups/knowledge/base.py | 57 ++++++++----------- 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index 594fe7bf..b3fd50ea 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -59,23 +59,34 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): @self.route( '//files', - methods=['GET'], + methods=['GET', 'POST'], endpoint='get_knowledge_base_files', ) async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: - files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) - return self.success( - data={ - 'files': [ - { - 'id': file.id, - 'file_name': file.file_name, - 'status': file.status, - } - for file in files - ], - } - ) + if quart.request.method == 'GET': + files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid) + return self.success( + data={ + 'files': [ + { + 'id': file.id, + 'file_name': file.file_name, + 'status': file.status, + } + for file in files + ], + } + ) + + elif quart.request.method == 'POST': + json_data = await quart.request.json + file_id = json_data.get('file_id') + if not file_id: + return self.http_status(400, -1, 'File ID is required') + + # 调用服务层方法将文件与知识库关联 + await self.ap.knowledge_base_service.relate_file_id_with_kb(knowledge_base_uuid, file_id) + return self.success({}) @self.route( '//files/', @@ -85,21 +96,3 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): async def delete_specific_file_in_kb(file_id: str) -> str: await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) return self.success({}) - - @self.route( - '//files', - methods=['POST'], - endpoint='relate_file_with_kb', - ) - async def relate_file_id_with_kb(knowledge_base_uuid: str, file_id: str) -> str: - if 'file' not in quart.request.files: - return self.http_status(400, -1, 'No file part in the request') - - json_data = await quart.request.json - file_id = json_data.get('file_id') - if not file_id: - return self.http_status(400, -1, 'File ID is required') - - # 调用服务层方法将文件与知识库关联 - await self.ap.knowledge_base_service.relate_file_id_with_kb(knowledge_base_uuid, file_id) - return self.success({}) From 6d788cadbc24348355810e7c7ae2d0b58f2e6124 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 21:37:31 +0800 Subject: [PATCH 30/60] fix: the fucking logger --- pkg/rag/knowledge/mgr.py | 102 ++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 54 deletions(-) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 585a5075..89e5b393 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -1,6 +1,5 @@ # rag_manager.py from __future__ import annotations -import logging import os import asyncio import uuid @@ -13,9 +12,8 @@ from pkg.core import app class RAGManager: ap: app.Application - def __init__(self, ap: app.Application, logger: logging.Logger = None): + def __init__(self, ap: app.Application): self.ap = ap - self.logger = logger or logging.getLogger(__name__) self.chroma_manager = None self.parser = FileParser() self.chunker = Chunker() @@ -52,20 +50,20 @@ class RAGManager: session.add(new_kb) session.commit() session.refresh(new_kb) - self.logger.info(f"Knowledge Base '{kb_name}' created.") + self.ap.logger.info(f"Knowledge Base '{kb_name}' created.") return new_kb.id else: - self.logger.info(f"Knowledge Base '{kb_name}' already exists.") + self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.") except Exception as e: session.rollback() - self.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True) + self.ap.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True) raise finally: session.close() return await asyncio.to_thread(_create_kb_sync) except Exception as e: - self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) + self.ap.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True) raise async def get_all_knowledge_bases(self): @@ -83,7 +81,7 @@ class RAGManager: return await asyncio.to_thread(_get_all_kbs_sync) except Exception as e: - self.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True) + self.ap.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True) return [] async def get_knowledge_base_by_id(self, kb_id: str): @@ -101,7 +99,7 @@ class RAGManager: return await asyncio.to_thread(_get_kb_sync, kb_id) except Exception as e: - self.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True) + self.ap.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True) return None async def get_files_by_knowledge_base(self, kb_id: str): @@ -119,7 +117,7 @@ class RAGManager: return await asyncio.to_thread(_get_files_sync, kb_id) except Exception as e: - self.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True) + self.ap.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True) return [] async def get_all_files(self): @@ -138,7 +136,7 @@ class RAGManager: return await asyncio.to_thread(_get_all_files_sync) except Exception as e: - self.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True) + self.ap.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True) return [] async def store_data( @@ -148,27 +146,25 @@ class RAGManager: Parses, chunks, embeds, and stores data from a given file into the RAG system. Associates the file with a knowledge base using kb_id in the File table. """ - self.logger.info(f'Starting data storage process for file: {file_path}') + self.ap.logger.info(f'Starting data storage process for file: {file_path}') session = SessionLocal() file_obj = None try: - kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() if not kb: kb = KnowledgeBase(name=kb_name, description=kb_description) session.add(kb) session.commit() session.refresh(kb) - self.logger.info(f"Knowledge Base '{kb_name}' created during store_data.") + self.ap.logger.info(f"Knowledge Base '{kb_name}' created during store_data.") else: - self.logger.info(f"Knowledge Base '{kb_name}' already exists.") + self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.") - file_name = os.path.basename(file_path) existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first() if existing_file: - self.logger.warning( + self.ap.logger.warning( f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage." ) return @@ -177,32 +173,32 @@ class RAGManager: session.add(file_obj) session.commit() session.refresh(file_obj) - self.logger.info( + self.ap.logger.info( f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}" ) - text = await self.parser.parse(file_path) if not text: - self.logger.warning(f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.') + self.ap.logger.warning( + f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.' + ) session.delete(file_obj) session.commit() return - chunks_texts = await self.chunker.chunk(text) - self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") + self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) - self.logger.info(f'Data storage process completed for file: {file_path}') + self.ap.logger.info(f'Data storage process completed for file: {file_path}') except Exception as e: session.rollback() - self.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True) + self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True) if file_obj and file_obj.id: try: await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id) except Exception as chroma_e: - self.logger.warning( + self.ap.logger.warning( f'Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}' ) raise @@ -213,13 +209,13 @@ class RAGManager: """ Retrieves relevant data chunks based on a given query using the configured retriever. """ - self.logger.info(f"Starting data retrieval process for query: '{query}'") + self.ap.logger.info(f"Starting data retrieval process for query: '{query}'") try: retrieved_chunks = await self.retriever.retrieve(query) - self.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.') + self.ap.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.') return retrieved_chunks except Exception as e: - self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) + self.ap.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) return [] async def delete_data_by_file_id(self, file_id: str): @@ -227,32 +223,34 @@ class RAGManager: Deletes all data associated with a specific file ID, including its chunks and vectors, and the file record itself. """ - self.logger.info(f'Starting data deletion process for file_id: {file_id}') + self.ap.logger.info(f'Starting data deletion process for file_id: {file_id}') session = SessionLocal() try: # 1. 从 ChromaDB 删除 embeddings await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) - self.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') + self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') # 2. 删除与文件关联的 chunks 记录 chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() for chunk in chunks_to_delete: session.delete(chunk) - self.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}') + self.ap.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}') # 3. 删除文件记录本身 file_to_delete = session.query(File).filter_by(id=file_id).first() if file_to_delete: session.delete(file_to_delete) - self.logger.info(f'Deleted file record for file_id: {file_id}') + self.ap.logger.info(f'Deleted file record for file_id: {file_id}') else: - self.logger.warning(f'File with ID {file_id} not found in database. Skipping deletion of file record.') + self.ap.logger.warning( + f'File with ID {file_id} not found in database. Skipping deletion of file record.' + ) session.commit() - self.logger.info(f'Successfully completed data deletion for file_id: {file_id}') + self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}') except Exception as e: session.rollback() - self.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True) + self.ap.logger.error(f'Error deleting data for file_id {file_id}: {str(e)}', exc_info=True) raise finally: session.close() @@ -262,45 +260,39 @@ class RAGManager: Deletes a knowledge base and all associated files, chunks, and vectors. This involves querying for associated files and then deleting them. """ - self.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}') - session = SessionLocal() + self.ap.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}') + session = SessionLocal() try: kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() if not kb_to_delete: - self.logger.warning(f'Knowledge Base with ID {kb_id} not found.') + self.ap.logger.warning(f'Knowledge Base with ID {kb_id} not found.') return - files_to_delete = session.query(File).filter_by(kb_id=kb_id).all() - session.close() - for file_obj in files_to_delete: try: await self.delete_data_by_file_id(file_obj.id) except Exception as file_del_e: - self.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}') - + self.ap.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}') - session = SessionLocal() try: - kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first() if kb_final_delete: session.delete(kb_final_delete) session.commit() - self.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}') + self.ap.logger.info(f'Successfully deleted knowledge base with ID: {kb_id}') else: - self.logger.warning( + self.ap.logger.warning( f'Knowledge Base with ID {kb_id} not found after file deletion, skipping KB deletion.' ) except Exception as kb_del_e: session.rollback() - self.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True) + self.ap.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True) raise finally: session.close() @@ -309,7 +301,9 @@ class RAGManager: # 如果在最初获取 KB 或文件列表时出错 if session.is_active: session.rollback() - self.logger.error(f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True) + self.ap.logger.error( + f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True + ) raise finally: if session.is_active: @@ -335,29 +329,29 @@ class RAGManager: """ Associates a file with a knowledge base by updating the kb_id in the File table. """ - self.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}') + self.ap.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}') session = SessionLocal() try: # 查询知识库是否存在 kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first() if not kb: - self.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.') + self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.') return # 更新文件的 kb_id file_to_update = session.query(File).filter_by(id=file_id).first() if not file_to_update: - self.logger.error(f'File with ID {file_id} not found.') + self.ap.logger.error(f'File with ID {file_id} not found.') return file_to_update.kb_id = kb.id session.commit() - self.logger.info( + self.ap.logger.info( f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}' ) except Exception as e: session.rollback() - self.logger.error( + self.ap.logger.error( f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}', exc_info=True, ) From fe122281fdaa056386098daeffe999eb1cf0325a Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 11 Jul 2025 21:40:42 +0800 Subject: [PATCH 31/60] feat(fe): component for available apis --- .../home/bots/components/bot-form/BotForm.tsx | 1 + web/src/app/home/knowledge/KBDetailDialog.tsx | 4 +++- .../knowledge/components/kb-docs/KBDoc.tsx | 5 ++++ .../kb-docs/doc-card/DocumentCard.tsx | 9 +++++++ .../knowledge/components/kb-form/KBForm.tsx | 24 ++++++++++++++++++- 5 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 web/src/app/home/knowledge/components/kb-docs/doc-card/DocumentCard.tsx diff --git a/web/src/app/home/bots/components/bot-form/BotForm.tsx b/web/src/app/home/bots/components/bot-form/BotForm.tsx index fe36d33b..e4b6d40e 100644 --- a/web/src/app/home/bots/components/bot-form/BotForm.tsx +++ b/web/src/app/home/bots/components/bot-form/BotForm.tsx @@ -212,6 +212,7 @@ export default function BotForm({ }); setAdapterNameToDynamicConfigMap(adapterNameToDynamicConfigMap); } + async function getBotConfig( botId: string, ): Promise> { diff --git a/web/src/app/home/knowledge/KBDetailDialog.tsx b/web/src/app/home/knowledge/KBDetailDialog.tsx index e3ab4f9d..d31306bd 100644 --- a/web/src/app/home/knowledge/KBDetailDialog.tsx +++ b/web/src/app/home/knowledge/KBDetailDialog.tsx @@ -24,6 +24,7 @@ import { z } from 'zod'; // import { httpClient } from '@/app/infra/http/HttpClient'; // import { KnowledgeBase } from '@/app/infra/entities/api'; import KBForm from '@/app/home/knowledge/components/kb-form/KBForm'; +import KBDoc from '@/app/home/knowledge/components/kb-docs/KBDoc'; interface KBDetailDialogProps { open: boolean; @@ -48,6 +49,7 @@ export default function KBDetailDialog({ const { t } = useTranslation(); const [kbId, setKbId] = useState(propKbId); const [activeMenu, setActiveMenu] = useState('metadata'); + const [fileId, setFileId] = useState(undefined); // const [showDeleteConfirm, setShowDeleteConfirm] = useState(false); useEffect(() => { @@ -177,7 +179,7 @@ export default function KBDetailDialog({ onNewKbCreated={onNewKbCreated} /> )} - {activeMenu === 'documents' &&
documents
} + {activeMenu === 'documents' && }
{activeMenu === 'metadata' && ( diff --git a/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx index e69de29b..5cc9a850 100644 --- a/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx +++ b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx @@ -0,0 +1,5 @@ +import { useEffect, useState } from 'react'; + +export default function KBDoc({ kbId }: { kbId: string }) { + return
Documents
; +} diff --git a/web/src/app/home/knowledge/components/kb-docs/doc-card/DocumentCard.tsx b/web/src/app/home/knowledge/components/kb-docs/doc-card/DocumentCard.tsx new file mode 100644 index 00000000..23a884ba --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-docs/doc-card/DocumentCard.tsx @@ -0,0 +1,9 @@ +export default function DocumentCard({ + kbId, + fileId, +}: { + kbId: string; + fileId: string; +}) { + return
; +} diff --git a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx index 0d4f0909..b56c327b 100644 --- a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx +++ b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx @@ -67,9 +67,31 @@ export default function KBForm({ >([]); useEffect(() => { - getEmbeddingModelNameList(); + getEmbeddingModelNameList().then(() => { + if (initKbId) { + getKbConfig(initKbId).then((val) => { + form.setValue('name', val.name); + form.setValue('description', val.description); + form.setValue('embeddingModelUUID', val.embeddingModelUUID); + }); + } + }); }, []); + const getKbConfig = async ( + kbId: string, + ): Promise> => { + return new Promise((resolve, reject) => { + httpClient.getKnowledgeBase(kbId).then((res) => { + resolve({ + name: res.base.name, + description: res.base.description, + embeddingModelUUID: res.base.embedding_model_uuid, + }); + }); + }); + }; + const getEmbeddingModelNameList = async () => { const resp = await httpClient.getProviderEmbeddingModels(); setEmbeddingModelNameList( From f395cac893a3e416b93f25f1a9af1c877514b0ea Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Sat, 12 Jul 2025 01:07:49 +0800 Subject: [PATCH 32/60] fix: embbeding and chunking --- .gitignore | 3 +- pkg/entity/persistence/rag.py | 5 +- pkg/rag/knowledge/mgr.py | 84 ++++++++++++-------------- pkg/rag/knowledge/services/chunker.py | 41 ++++++------- pkg/rag/knowledge/services/embedder.py | 2 +- 5 files changed, 64 insertions(+), 71 deletions(-) diff --git a/.gitignore b/.gitignore index 2869b7cc..db62bdca 100644 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,5 @@ botpy.log* test.py /web_ui .venv/ -uv.lock \ No newline at end of file +uv.lock +/test \ No newline at end of file diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py index 95a78712..9ca84741 100644 --- a/pkg/entity/persistence/rag.py +++ b/pkg/entity/persistence/rag.py @@ -2,6 +2,7 @@ from sqlalchemy import create_engine, Column, String, Text, DateTime, LargeBinar from sqlalchemy.orm import declarative_base, sessionmaker from datetime import datetime import os +import uuid Base = declarative_base() DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db') @@ -35,11 +36,11 @@ class File(Base): path = Column(String) created_at = Column(DateTime, default=datetime.utcnow) file_type = Column(String) - status = Column(String, default='0') + status = Column(Integer, default=0) # 0: uploaded and processing, 1: completed, 2: failed class Chunk(Base): __tablename__ = 'chunks' - id = Column(String, primary_key=True, index=True) + id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) file_id = Column(String, nullable=True) text = Column(Text) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 89e5b393..09023d03 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -7,6 +7,9 @@ from pkg.rag.knowledge.services.parser import FileParser from pkg.rag.knowledge.services.chunker import Chunker from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk from pkg.core import app +from pkg.rag.knowledge.services.embedder import Embedder +from pkg.rag.knowledge.services.retriever import Retriever +from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager class RAGManager: @@ -14,11 +17,12 @@ class RAGManager: def __init__(self, ap: app.Application): self.ap = ap - self.chroma_manager = None + self.chroma_manager = ChromaIndexManager() self.parser = FileParser() self.chunker = Chunker() - self.embedder = None - self.retriever = None + # Initialize Embedder with targeted model type and name + self.embedder = Embedder(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager) + self.retriever = Retriever(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager) async def initialize_rag_system(self): """Initializes the RAG system by creating database tables.""" @@ -140,7 +144,7 @@ class RAGManager: return [] async def store_data( - self, file_path: str, kb_name: str, file_type: str, kb_description: str = 'Default knowledge base' + self, file_path: str, kb_id: str, file_type: str, file_id: str = None ): """ Parses, chunks, embeds, and stores data from a given file into the RAG system. @@ -151,58 +155,35 @@ class RAGManager: file_obj = None try: - kb = session.query(KnowledgeBase).filter_by(name=kb_name).first() + kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() if not kb: - kb = KnowledgeBase(name=kb_name, description=kb_description) - session.add(kb) - session.commit() - session.refresh(kb) - self.ap.logger.info(f"Knowledge Base '{kb_name}' created during store_data.") + self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ') + self.ap.logger.info(f'Created Knowledge Base with ID: {kb_id}') else: - self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.") + self.ap.logger.info(f"Knowledge Base '{kb_id}' already exists.") file_name = os.path.basename(file_path) - existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first() - if existing_file: - self.ap.logger.warning( - f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage." - ) - return - - file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type) - session.add(file_obj) - session.commit() - session.refresh(file_obj) - self.ap.logger.info( - f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}" - ) - text = await self.parser.parse(file_path) if not text: self.ap.logger.warning( - f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.' + f'No text extracted from file {file_path}. ' ) - session.delete(file_obj) - session.commit() return chunks_texts = await self.chunker.chunk(text) self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") - await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) + await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts) self.ap.logger.info(f'Data storage process completed for file: {file_path}') except Exception as e: session.rollback() self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True) - if file_obj and file_obj.id: - try: - await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id) - except Exception as chroma_e: - self.ap.logger.warning( - f'Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}' - ) raise finally: + if file_id: + file_obj = session.query(File).filter_by(id=file_id).first() + if file_obj: + file_obj.status = 1 session.close() async def retrieve_data(self, query: str): @@ -245,7 +226,6 @@ class RAGManager: self.ap.logger.warning( f'File with ID {file_id} not found in database. Skipping deletion of file record.' ) - session.commit() self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}') except Exception as e: @@ -338,13 +318,13 @@ class RAGManager: self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.') return - # 更新文件的 kb_id - file_to_update = session.query(File).filter_by(id=file_id).first() - if not file_to_update: - self.ap.logger.error(f'File with ID {file_id} not found.') + if not self.ap.storage_mgr.storage_provider.exists(file_id): + self.ap.logger.error(f'File with ID {file_id} does not exist.') return - - file_to_update.kb_id = kb.id + self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.') + # add new file record + file_to_update = File(id=file_id, kb_id=kb.id) + session.add(file_to_update) session.commit() self.ap.logger.info( f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}' @@ -356,4 +336,20 @@ class RAGManager: exc_info=True, ) finally: + # 进行文件解析 + try: + await self.store_data( + file_path = os.path.join('data', 'storage', file_id), + kb_id=knowledge_base_uuid, + file_type=os.path.splitext(file_id)[1].lstrip('.'), + file_id=file_id + ) + except Exception as store_e: + # 如果存储数据时出错,更新文件状态为失败 + file_obj = session.query(File).filter_by(id=file_id).first() + if file_obj: + file_obj.status = 2 + session.commit() + self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True) + session.close() diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py index 17202a7a..2db7c104 100644 --- a/pkg/rag/knowledge/services/chunker.py +++ b/pkg/rag/knowledge/services/chunker.py @@ -24,33 +24,28 @@ class Chunker(BaseService): """ if not text: return [] - - # Simple whitespace-based splitting for demonstration - # For more advanced chunking, consider libraries like LangChain's text splitters - words = text.split() - chunks = [] - current_chunk = [] + # words = text.split() + # chunks = [] + # current_chunk = [] - for word in words: - current_chunk.append(word) - if len(current_chunk) > self.chunk_size: - chunks.append(" ".join(current_chunk[:self.chunk_size])) - current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:] + # for word in words: + # current_chunk.append(word) + # if len(current_chunk) > self.chunk_size: + # chunks.append(" ".join(current_chunk[:self.chunk_size])) + # current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:] - if current_chunk: - chunks.append(" ".join(current_chunk)) + # if current_chunk: + # chunks.append(" ".join(current_chunk)) # A more robust chunking strategy (e.g., using recursive character text splitter) - # from langchain.text_splitter import RecursiveCharacterTextSplitter - # text_splitter = RecursiveCharacterTextSplitter( - # chunk_size=self.chunk_size, - # chunk_overlap=self.chunk_overlap, - # length_function=len, - # is_separator_regex=False, - # ) - # return text_splitter.split_text(text) - - return [chunk for chunk in chunks if chunk.strip()] # Filter out empty chunks + from langchain.text_splitter import RecursiveCharacterTextSplitter + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.chunk_overlap, + length_function=len, + is_separator_regex=False, + ) + return text_splitter.split_text(text) async def chunk(self, text: str) -> List[str]: """ diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 7e20b19a..063ae79e 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -12,7 +12,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Impor logger = logging.getLogger(__name__) class Embedder(BaseService): - def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): + def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager = None): super().__init__() self.logger = logging.getLogger(self.__class__.__name__) self.model_type = model_type From 9f43097361c33eb1ca8c60a2268d657fa0dbbe86 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Sat, 12 Jul 2025 01:21:02 +0800 Subject: [PATCH 33/60] fix: ensure File.status is set correctly after storing data to avoid null values --- pkg/rag/knowledge/mgr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index 09023d03..a5d5f513 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -318,12 +318,12 @@ class RAGManager: self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.') return - if not self.ap.storage_mgr.storage_provider.exists(file_id): + if not await self.ap.storage_mgr.storage_provider.exists(file_id): self.ap.logger.error(f'File with ID {file_id} does not exist.') return self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.') # add new file record - file_to_update = File(id=file_id, kb_id=kb.id) + file_to_update = File(id=file_id, kb_id=kb.id, file_name=file_id, path=os.path.join('data', 'storage', file_id), file_type=os.path.splitext(file_id)[1].lstrip('.'), status=0) session.add(file_to_update) session.commit() self.ap.logger.info( From 234b61e2f8c0629b7c9de0eb575d8c28269d2196 Mon Sep 17 00:00:00 2001 From: WangCham <651122857@qq.com> Date: Sat, 12 Jul 2025 01:37:44 +0800 Subject: [PATCH 34/60] fix: add functions for deleting files --- pkg/api/http/controller/groups/knowledge/base.py | 2 +- pkg/rag/knowledge/mgr.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index b3fd50ea..50183f0f 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -93,6 +93,6 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): methods=['DELETE'], endpoint='delete_specific_file_in_kb', ) - async def delete_specific_file_in_kb(file_id: str) -> str: + async def delete_specific_file_in_kb(file_id: str,knowledge_base_uuid: str) -> str: await self.ap.knowledge_base_service.delete_data_by_file_id(file_id) return self.success({}) diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py index a5d5f513..6ebc85a7 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/mgr.py @@ -207,20 +207,23 @@ class RAGManager: self.ap.logger.info(f'Starting data deletion process for file_id: {file_id}') session = SessionLocal() try: - # 1. 从 ChromaDB 删除 embeddings + # delete vectors await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') - # 2. 删除与文件关联的 chunks 记录 chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() for chunk in chunks_to_delete: session.delete(chunk) self.ap.logger.info(f'Deleted {len(chunks_to_delete)} chunk records for file_id: {file_id}') - # 3. 删除文件记录本身 file_to_delete = session.query(File).filter_by(id=file_id).first() if file_to_delete: session.delete(file_to_delete) + try: + await self.ap.storage_mgr.storage_provider.delete(file_id) + except Exception as e: + self.ap.logger.error(f'Error deleting file from storage for file_id {file_id}: {str(e)}', exc_info=True) + await self.ap.storage_mgr.storage_provider.delete(file_id) self.ap.logger.info(f'Deleted file record for file_id: {file_id}') else: self.ap.logger.warning( From d78a329aa9bce20c525ca442f55d12120d303e23 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 12 Jul 2025 17:15:07 +0800 Subject: [PATCH 35/60] feat(fe): file uploading --- web/package-lock.json | 490 ++++++++++++++++-- web/package.json | 1 + .../components/kb-docs/FileUploadZone.tsx | 145 ++++++ .../knowledge/components/kb-docs/KBDoc.tsx | 45 +- .../kb-docs/doc-card/DocumentCard.tsx | 9 - .../components/kb-docs/documents/columns.tsx | 24 + .../kb-docs/documents/data-table.tsx | 81 +++ web/src/app/infra/http/HttpClient.ts | 12 +- web/src/components/ui/table.tsx | 116 +++++ web/src/i18n/locales/en-US.ts | 12 + web/src/i18n/locales/ja-JP.ts | 33 +- web/src/i18n/locales/zh-Hans.ts | 28 + 12 files changed, 937 insertions(+), 59 deletions(-) create mode 100644 web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx delete mode 100644 web/src/app/home/knowledge/components/kb-docs/doc-card/DocumentCard.tsx create mode 100644 web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx create mode 100644 web/src/app/home/knowledge/components/kb-docs/documents/data-table.tsx create mode 100644 web/src/components/ui/table.tsx diff --git a/web/package-lock.json b/web/package-lock.json index ee9b5767..fcc17852 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -12,23 +12,27 @@ "@dnd-kit/sortable": "^10.0.0", "@hookform/resolvers": "^5.0.1", "@radix-ui/react-checkbox": "^1.3.1", - "@radix-ui/react-dialog": "^1.1.13", + "@radix-ui/react-dialog": "^1.1.14", "@radix-ui/react-hover-card": "^1.1.13", "@radix-ui/react-label": "^2.1.6", "@radix-ui/react-popover": "^1.1.14", "@radix-ui/react-scroll-area": "^1.2.9", "@radix-ui/react-select": "^2.2.4", - "@radix-ui/react-slot": "^1.2.2", + "@radix-ui/react-separator": "^1.1.7", + "@radix-ui/react-slot": "^1.2.3", "@radix-ui/react-switch": "^1.2.4", "@radix-ui/react-tabs": "^1.1.11", "@radix-ui/react-toggle": "^1.1.8", "@radix-ui/react-toggle-group": "^1.1.9", + "@radix-ui/react-tooltip": "^1.2.7", "@tailwindcss/postcss": "^4.1.5", + "@tanstack/react-table": "^8.21.3", "axios": "^1.8.4", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "i18next": "^25.1.2", "i18next-browser-languagedetector": "^8.1.0", + "input-otp": "^1.4.2", "lodash": "^4.17.21", "lucide-react": "^0.507.0", "next": "15.2.4", @@ -1037,6 +1041,24 @@ } } }, + "node_modules/@radix-ui/react-collection/node_modules/@radix-ui/react-slot": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.2.tgz", + "integrity": "sha512-y7TBO4xN4Y94FvcWIOIh18fM4R1A8S4q1jhoz4PNzOoHsFcN8pogcFmZrTYAm4F9VRUrWP/Mw7xSKybIeRI+CQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-compose-refs": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.2.tgz", @@ -1068,22 +1090,22 @@ } }, "node_modules/@radix-ui/react-dialog": { - "version": "1.1.13", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.13.tgz", - "integrity": "sha512-ARFmqUyhIVS3+riWzwGTe7JLjqwqgnODBUZdqpWar/z1WFs9z76fuOs/2BOWCR+YboRn4/WN9aoaGVwqNRr8VA==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.14.tgz", + "integrity": "sha512-+CpweKjqpzTmwRwcYECQcNYbI8V9VSQt0SNFKeEBLgfucbsLssU6Ppq7wUdNXEGb573bMjFhVjKVll8rmV6zMw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.9", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.6", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-portal": "1.1.8", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.2", - "@radix-ui/react-slot": "1.2.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-controllable-state": "1.2.2", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" @@ -1103,6 +1125,105 @@ } } }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.10.tgz", + "integrity": "sha512-IM1zzRV4W3HtVgftdQiiOmA0AdJlCtMLe00FXaHwgt3rAnNsIyDqshvkIW3hj/iu5hu8ERP7KIYki6NkqDxAwQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-escape-keydown": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-focus-scope": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz", + "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-portal": { + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", + "integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-direction": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@radix-ui/react-direction/-/react-direction-1.1.1.tgz", @@ -1448,24 +1569,6 @@ } } }, - "node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-slot": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", - "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", - "license": "MIT", - "dependencies": { - "@radix-ui/react-compose-refs": "1.1.2" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, "node_modules/@radix-ui/react-popper": { "version": "1.2.6", "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.6.tgz", @@ -1569,6 +1672,24 @@ } } }, + "node_modules/@radix-ui/react-primitive/node_modules/@radix-ui/react-slot": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.2.tgz", + "integrity": "sha512-y7TBO4xN4Y94FvcWIOIh18fM4R1A8S4q1jhoz4PNzOoHsFcN8pogcFmZrTYAm4F9VRUrWP/Mw7xSKybIeRI+CQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-roving-focus": { "version": "1.1.9", "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.9.tgz", @@ -1654,24 +1775,6 @@ } } }, - "node_modules/@radix-ui/react-scroll-area/node_modules/@radix-ui/react-slot": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", - "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", - "license": "MIT", - "dependencies": { - "@radix-ui/react-compose-refs": "1.1.2" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, "node_modules/@radix-ui/react-select": { "version": "2.2.4", "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.4.tgz", @@ -1715,7 +1818,7 @@ } } }, - "node_modules/@radix-ui/react-slot": { + "node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-slot": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.2.tgz", "integrity": "sha512-y7TBO4xN4Y94FvcWIOIh18fM4R1A8S4q1jhoz4PNzOoHsFcN8pogcFmZrTYAm4F9VRUrWP/Mw7xSKybIeRI+CQ==", @@ -1733,6 +1836,70 @@ } } }, + "node_modules/@radix-ui/react-separator": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.7.tgz", + "integrity": "sha512-0HEb8R9E8A+jZjvmFCy/J4xhbXy3TV+9XSnGJ3KvTtjlIUy/YQ/p6UYZvi7YbeoeXdyU9+Y3scizK6hkY37baA==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-separator/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slot": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", + "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-switch": { "version": "1.2.4", "resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.2.4.tgz", @@ -1846,6 +2013,192 @@ } } }, + "node_modules/@radix-ui/react-tooltip": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.7.tgz", + "integrity": "sha512-Ap+fNYwKTYJ9pzqW+Xe2HtMRbQ/EeWkj2qykZ6SuEV4iS/o1bZI5ssJbk4D2r8XuDuOBVz/tIx2JObtuqU+5Zw==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-presence": "1.1.4", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-visually-hidden": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-arrow": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.7.tgz", + "integrity": "sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.10.tgz", + "integrity": "sha512-IM1zzRV4W3HtVgftdQiiOmA0AdJlCtMLe00FXaHwgt3rAnNsIyDqshvkIW3hj/iu5hu8ERP7KIYki6NkqDxAwQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-escape-keydown": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-popper": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.7.tgz", + "integrity": "sha512-IUFAccz1JyKcf/RjB552PlWwxjeCJB8/4KxT7EhBHOJM+mN7LdW+B3kacJXILm32xawcMMjb2i0cIZpo+f9kiQ==", + "license": "MIT", + "dependencies": { + "@floating-ui/react-dom": "^2.0.0", + "@radix-ui/react-arrow": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-layout-effect": "1.1.1", + "@radix-ui/react-use-rect": "1.1.1", + "@radix-ui/react-use-size": "1.1.1", + "@radix-ui/rect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-portal": { + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", + "integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-visually-hidden": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.3.tgz", + "integrity": "sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-use-callback-ref": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.1.1.tgz", @@ -2295,6 +2648,39 @@ "tailwindcss": "4.1.5" } }, + "node_modules/@tanstack/react-table": { + "version": "8.21.3", + "resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.21.3.tgz", + "integrity": "sha512-5nNMTSETP4ykGegmVkhjcS8tTLW6Vl4axfEGQN3v0zdHYbK4UfoqfPChclTrJ4EoK9QynqAu9oUf8VEmrpZ5Ww==", + "license": "MIT", + "dependencies": { + "@tanstack/table-core": "8.21.3" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "react": ">=16.8", + "react-dom": ">=16.8" + } + }, + "node_modules/@tanstack/table-core": { + "version": "8.21.3", + "resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.21.3.tgz", + "integrity": "sha512-ldZXEhOBb8Is7xLs01fR3YEc3DERiz5silj8tnGkFZytt1abEvl/GhUmCE0PMLaMPTa3Jk4HbKmRlHmu+gCftg==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + } + }, "node_modules/@tybys/wasm-util": { "version": "0.9.0", "resolved": "https://registry.npmjs.org/@tybys/wasm-util/-/wasm-util-0.9.0.tgz", @@ -4763,6 +5149,16 @@ "node": ">=0.8.19" } }, + "node_modules/input-otp": { + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/input-otp/-/input-otp-1.4.2.tgz", + "integrity": "sha512-l3jWwYNvrEa6NTCt7BECfCm48GvwuZzkoeG3gBL2w4CHeOXW3eKFmf9UNYkNfYc3mxMrthMnxjIE07MT0zLBQA==", + "license": "MIT", + "peerDependencies": { + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0.0 || ^19.0.0-rc" + } + }, "node_modules/internal-slot": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.1.0.tgz", diff --git a/web/package.json b/web/package.json index 458e4132..d5e8542c 100644 --- a/web/package.json +++ b/web/package.json @@ -35,6 +35,7 @@ "@radix-ui/react-toggle-group": "^1.1.9", "@radix-ui/react-tooltip": "^1.2.7", "@tailwindcss/postcss": "^4.1.5", + "@tanstack/react-table": "^8.21.3", "axios": "^1.8.4", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", diff --git a/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx b/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx new file mode 100644 index 00000000..8c072bdf --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx @@ -0,0 +1,145 @@ +import React, { useCallback, useState } from 'react'; +import { Card, CardContent } from '@/components/ui/card'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { toast } from 'sonner'; +import { useTranslation } from 'react-i18next'; + +interface FileUploadZoneProps { + kbId: string; + onUploadSuccess: () => void; + onUploadError: (error: string) => void; +} + +export default function FileUploadZone({ + kbId, + onUploadSuccess, + onUploadError, +}: FileUploadZoneProps) { + const { t } = useTranslation(); + const [isDragOver, setIsDragOver] = useState(false); + const [isUploading, setIsUploading] = useState(false); + + const handleUpload = useCallback( + async (file: File) => { + if (isUploading) return; + + setIsUploading(true); + const toastId = toast.loading(t('knowledge.documentsTab.uploadingFile')); + + try { + // Step 1: Upload file to server + const uploadResult = await httpClient.uploadDocumentFile(file); + + // Step 2: Associate file with knowledge base + await httpClient.uploadKnowledgeBaseFile(kbId, uploadResult.file_id); + + toast.success(t('knowledge.documentsTab.uploadSuccess'), { + id: toastId, + }); + onUploadSuccess(); + } catch (error) { + console.error('File upload failed:', error); + const errorMessage = t('knowledge.documentsTab.uploadError'); + toast.error(errorMessage, { id: toastId }); + onUploadError(errorMessage); + } finally { + setIsUploading(false); + } + }, + [kbId, isUploading, onUploadSuccess, onUploadError], + ); + + const handleDragOver = useCallback((e: React.DragEvent) => { + e.preventDefault(); + setIsDragOver(true); + }, []); + + const handleDragLeave = useCallback((e: React.DragEvent) => { + e.preventDefault(); + setIsDragOver(false); + }, []); + + const handleDrop = useCallback( + (e: React.DragEvent) => { + e.preventDefault(); + setIsDragOver(false); + + const files = Array.from(e.dataTransfer.files); + if (files.length > 0) { + handleUpload(files[0]); + } + }, + [handleUpload], + ); + + const handleFileSelect = useCallback( + (e: React.ChangeEvent) => { + const files = e.target.files; + if (files && files.length > 0) { + handleUpload(files[0]); + } + }, + [handleUpload], + ); + + return ( + + +
+ + + +
+
+
+ ); +} diff --git a/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx index 5cc9a850..b1730602 100644 --- a/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx +++ b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx @@ -1,5 +1,48 @@ import { useEffect, useState } from 'react'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { KnowledgeBaseFile } from '@/app/infra/entities/api'; +import { columns, DocumentFile } from './documents/columns'; +import { DataTable } from './documents/data-table'; +import FileUploadZone from './FileUploadZone'; export default function KBDoc({ kbId }: { kbId: string }) { - return
Documents
; + const [documentsList, setDocumentsList] = useState([]); + + useEffect(() => { + getDocumentsList(); + }, []); + + async function getDocumentsList() { + const resp = await httpClient.getKnowledgeBaseFiles(kbId); + setDocumentsList( + resp.files.map((file: KnowledgeBaseFile) => { + return { + id: file.file_id, + name: file.file_name, + status: file.status, + }; + }), + ); + } + + const handleUploadSuccess = () => { + // Refresh document list after successful upload + getDocumentsList(); + }; + + const handleUploadError = (error: string) => { + // Error messages are already handled by toast in FileUploadZone component + console.error('Upload failed:', error); + }; + + return ( +
+ + +
+ ); } diff --git a/web/src/app/home/knowledge/components/kb-docs/doc-card/DocumentCard.tsx b/web/src/app/home/knowledge/components/kb-docs/doc-card/DocumentCard.tsx deleted file mode 100644 index 23a884ba..00000000 --- a/web/src/app/home/knowledge/components/kb-docs/doc-card/DocumentCard.tsx +++ /dev/null @@ -1,9 +0,0 @@ -export default function DocumentCard({ - kbId, - fileId, -}: { - kbId: string; - fileId: string; -}) { - return
; -} diff --git a/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx b/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx new file mode 100644 index 00000000..d43afd68 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx @@ -0,0 +1,24 @@ +'use client'; + +import { ColumnDef } from '@tanstack/react-table'; +import { useTranslation } from 'react-i18next'; + +export type DocumentFile = { + id: string; + name: string; + status: string; +}; + +export const columns = (): ColumnDef[] => { + const { t } = useTranslation(); + return [ + { + accessorKey: 'name', + header: t('knowledge.documentsTab.name'), + }, + { + accessorKey: 'status', + header: t('knowledge.documentsTab.status'), + }, + ]; +}; diff --git a/web/src/app/home/knowledge/components/kb-docs/documents/data-table.tsx b/web/src/app/home/knowledge/components/kb-docs/documents/data-table.tsx new file mode 100644 index 00000000..178ccad9 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-docs/documents/data-table.tsx @@ -0,0 +1,81 @@ +'use client'; + +import { + ColumnDef, + flexRender, + getCoreRowModel, + useReactTable, +} from '@tanstack/react-table'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table'; +import { useTranslation } from 'react-i18next'; + +interface DataTableProps { + columns: ColumnDef[]; + data: TData[]; +} + +export function DataTable({ + columns, + data, +}: DataTableProps) { + const { t } = useTranslation(); + const table = useReactTable({ + data, + columns, + getCoreRowModel: getCoreRowModel(), + }); + + return ( +
+ + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => { + return ( + + {header.isPlaceholder + ? null + : flexRender( + header.column.columnDef.header, + header.getContext(), + )} + + ); + })} + + ))} + + + {table.getRowModel().rows?.length ? ( + table.getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ))} + + )) + ) : ( + + + {t('knowledge.documentsTab.noResults')} + + + )} + +
+
+ ); +} diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 8842b04d..3a0b5f35 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -433,7 +433,17 @@ class HttpClient { // ============ File management API ============ public uploadDocumentFile(file: File): Promise<{ file_id: string }> { - return this.post('/api/v1/files/documents', file); + const formData = new FormData(); + formData.append('file', file); + + return this.request<{ file_id: string }>({ + method: 'post', + url: '/api/v1/files/documents', + data: formData, + headers: { + 'Content-Type': 'multipart/form-data', + }, + }); } // ============ Knowledge Base API ============ diff --git a/web/src/components/ui/table.tsx b/web/src/components/ui/table.tsx new file mode 100644 index 00000000..ebded8ed --- /dev/null +++ b/web/src/components/ui/table.tsx @@ -0,0 +1,116 @@ +'use client'; + +import * as React from 'react'; + +import { cn } from '@/lib/utils'; + +function Table({ className, ...props }: React.ComponentProps<'table'>) { + return ( +
+ + + ); +} + +function TableHeader({ className, ...props }: React.ComponentProps<'thead'>) { + return ( + + ); +} + +function TableBody({ className, ...props }: React.ComponentProps<'tbody'>) { + return ( + + ); +} + +function TableFooter({ className, ...props }: React.ComponentProps<'tfoot'>) { + return ( + tr]:last:border-b-0', + className, + )} + {...props} + /> + ); +} + +function TableRow({ className, ...props }: React.ComponentProps<'tr'>) { + return ( + + ); +} + +function TableHead({ className, ...props }: React.ComponentProps<'th'>) { + return ( +
[role=checkbox]]:translate-y-[2px]', + className, + )} + {...props} + /> + ); +} + +function TableCell({ className, ...props }: React.ComponentProps<'td'>) { + return ( + [role=checkbox]]:translate-y-[2px]', + className, + )} + {...props} + /> + ); +} + +function TableCaption({ + className, + ...props +}: React.ComponentProps<'caption'>) { + return ( +
+ ); +} + +export { + Table, + TableHeader, + TableBody, + TableFooter, + TableHead, + TableRow, + TableCell, + TableCaption, +}; diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index ecc43204..cfb50966 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -251,6 +251,18 @@ const enUS = { embeddingModelDescription: 'Used to vectorize the text, you can configure it in the Models page', updateTime: 'Updated ', + documentsTab: { + name: 'Name', + status: 'Status', + noResults: 'No results', + dragAndDrop: 'Drag and drop files here or click to upload', + uploading: 'Uploading...', + supportedFormats: + 'Supports PDF, Word, TXT, Markdown and other document formats', + uploadSuccess: 'File uploaded successfully!', + uploadError: 'File upload failed, please try again', + uploadingFile: 'Uploading file...', + }, }, register: { title: 'Initialize LangBot 👋', diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index 21b0ff7d..639549b1 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -234,7 +234,38 @@ const jaJP = { }, knowledge: { title: '知識ベース', - description: 'LLMの応答品質を向上させるための知識ベースを設定します', + createKnowledgeBase: '知識ベースを作成', + editKnowledgeBase: '知識ベースを編集', + editDocument: 'ドキュメント', + description: 'LLMの回答品質向上のための知識ベースを設定します', + metadata: 'メタデータ', + documents: 'ドキュメント', + kbNameRequired: '知識ベース名は必須です', + kbDescriptionRequired: '知識ベースの説明は必須です', + embeddingModelUUIDRequired: '埋め込みモデルは必須です', + daysAgo: '日前', + today: '今日', + kbName: '知識ベース名', + kbDescription: '知識ベースの説明', + defaultDescription: '知識ベース', + embeddingModelUUID: '埋め込みモデル', + selectEmbeddingModel: '埋め込みモデルを選択', + embeddingModelDescription: + 'テキストのベクトル化に使用する埋め込みモデルを管理します', + updateTime: '更新日時', + documentsTab: { + name: '名前', + status: 'ステータス', + noResults: '結果がありません', + dragAndDrop: + 'ファイルをここにドラッグ&ドロップするか、クリックしてアップロードしてください', + uploading: 'アップロード中...', + supportedFormats: + 'PDF、Word、TXT、Markdownなどのドキュメントファイルをサポートしています', + uploadSuccess: 'ファイルのアップロードに成功しました!', + uploadError: 'ファイルのアップロードに失敗しました。再度お試しください', + uploadingFile: 'ファイルをアップロード中...', + }, }, register: { title: 'LangBot を初期化 👋', diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 1bd04ca8..71089fa2 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -227,7 +227,35 @@ const zhHans = { }, knowledge: { title: '知识库', + createKnowledgeBase: '创建知识库', + editKnowledgeBase: '编辑知识库', + editDocument: '文档', description: '配置可用于提升模型回复质量的知识库', + metadata: '元数据', + documents: '文档', + kbNameRequired: '知识库名称不能为空', + kbDescriptionRequired: '知识库描述不能为空', + embeddingModelUUIDRequired: '嵌入模型不能为空', + daysAgo: '天前', + today: '今天', + kbName: '知识库名称', + kbDescription: '知识库描述', + defaultDescription: '一个知识库', + embeddingModelUUID: '嵌入模型', + selectEmbeddingModel: '选择嵌入模型', + embeddingModelDescription: '用于向量化文本,可在模型配置页面配置', + updateTime: '更新于', + documentsTab: { + name: '名称', + status: '状态', + noResults: '暂无结果', + dragAndDrop: '拖拽文件到此处或点击上传', + uploading: '上传中...', + supportedFormats: '支持 PDF、Word、TXT、Markdown 等文档格式', + uploadSuccess: '文件上传成功!', + uploadError: '文件上传失败,请重试', + uploadingFile: '上传文件中...', + }, }, register: { title: '初始化 LangBot 👋', From 1e85d02ae4a328b3c8f96baf198070cf99945cce Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 12 Jul 2025 17:29:39 +0800 Subject: [PATCH 36/60] perf: adjust ui --- .../components/kb-docs/FileUploadZone.tsx | 14 +++++++------- .../home/knowledge/components/kb-docs/KBDoc.tsx | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx b/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx index 8c072bdf..aa8adede 100644 --- a/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx +++ b/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx @@ -84,10 +84,10 @@ export default function FileUploadZone({ return ( - +