diff --git a/pkg/api/http/controller/groups/provider/models.py b/pkg/api/http/controller/groups/provider/models.py index bb77986c..0de0c922 100644 --- a/pkg/api/http/controller/groups/provider/models.py +++ b/pkg/api/http/controller/groups/provider/models.py @@ -9,18 +9,18 @@ class LLMModelsRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={'models': await self.ap.model_service.get_llm_models()}) + return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()}) elif quart.request.method == 'POST': json_data = await quart.request.json - model_uuid = await self.ap.model_service.create_llm_model(json_data) + model_uuid = await self.ap.llm_model_service.create_llm_model(json_data) return self.success(data={'uuid': model_uuid}) @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(model_uuid: str) -> str: if quart.request.method == 'GET': - model = await self.ap.model_service.get_llm_model(model_uuid) + model = await self.ap.llm_model_service.get_llm_model(model_uuid) if model is None: return self.http_status(404, -1, 'model not found') @@ -29,11 +29,11 @@ class LLMModelsRouterGroup(group.RouterGroup): elif quart.request.method == 'PUT': json_data = await quart.request.json - await self.ap.model_service.update_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.update_llm_model(model_uuid, json_data) return self.success() elif quart.request.method == 'DELETE': - await self.ap.model_service.delete_llm_model(model_uuid) + await self.ap.llm_model_service.delete_llm_model(model_uuid) return self.success() @@ -41,6 +41,49 @@ class LLMModelsRouterGroup(group.RouterGroup): async def _(model_uuid: str) -> str: json_data = await quart.request.json - await self.ap.model_service.test_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.test_llm_model(model_uuid, json_data) + + return self.success() + + +@group.group_class('models/embedding', '/api/v1/provider/models/embedding') +class EmbeddingModelsRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['GET', 'POST']) + async def _() -> str: + if quart.request.method == 'GET': + return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + + model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data) + + return self.success(data={'uuid': model_uuid}) + + @self.route('/', methods=['GET', 'PUT', 'DELETE']) + async def _(model_uuid: str) -> str: + if quart.request.method == 'GET': + model = await self.ap.embedding_models_service.get_embedding_model(model_uuid) + + if model is None: + return self.http_status(404, -1, 'model not found') + + return self.success(data={'model': model}) + elif quart.request.method == 'PUT': + json_data = await quart.request.json + + await self.ap.embedding_models_service.update_embedding_model(model_uuid, json_data) + + return self.success() + elif quart.request.method == 'DELETE': + await self.ap.embedding_models_service.delete_embedding_model(model_uuid) + + return self.success() + + @self.route('//test', methods=['POST']) + async def _(model_uuid: str) -> str: + json_data = await quart.request.json + + await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data) return self.success() diff --git a/pkg/api/http/controller/groups/provider/requesters.py b/pkg/api/http/controller/groups/provider/requesters.py index 0f999288..af9e1540 100644 --- a/pkg/api/http/controller/groups/provider/requesters.py +++ b/pkg/api/http/controller/groups/provider/requesters.py @@ -8,7 +8,8 @@ class RequestersRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> quart.Response: - return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()}) + model_type = quart.request.args.get('type', '') + return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info(model_type)}) @self.route('/', methods=['GET']) async def _(requester_name: str) -> quart.Response: diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 74fb4e02..afeae3eb 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -10,7 +10,7 @@ from ....provider.modelmgr import requester as model_requester from ....provider import entities as llm_entities -class ModelsService: +class LLMModelsService: ap: app.Application def __init__(self, ap: app.Application) -> None: @@ -103,3 +103,90 @@ class ModelsService: funcs=[], extra_args={}, ) + + +class EmbeddingModelsService: + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_embedding_models(self) -> list[dict]: + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + + models = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models] + + async def create_embedding_model(self, model_data: dict) -> str: + model_data['uuid'] = str(uuid.uuid4()) + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) + ) + + embedding_model = await self.get_embedding_model(model_data['uuid']) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + return model_data['uuid'] + + async def get_embedding_model(self, model_uuid: str) -> dict | None: + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + model = result.first() + + if model is None: + return None + + return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) + + async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None: + if 'uuid' in model_data: + del model_data['uuid'] + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.EmbeddingModel) + .where(persistence_model.EmbeddingModel.uuid == model_uuid) + .values(**model_data) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + embedding_model = await self.get_embedding_model(model_uuid) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + async def delete_embedding_model(self, model_uuid: str) -> None: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None: + runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None + + if model_uuid != '_': + for model in self.ap.model_mgr.embedding_models: + if model.model_entity.uuid == model_uuid: + runtime_embedding_model = model + break + + if runtime_embedding_model is None: + raise Exception('model not found') + + else: + runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data) + + await runtime_embedding_model.requester.invoke_embedding( + query=None, + model=runtime_embedding_model, + input_text='Hello, world!', + extra_args={}, + ) diff --git a/pkg/core/app.py b/pkg/core/app.py index 911acd3d..318cddcb 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -103,7 +103,9 @@ class Application: user_service: user_service.UserService = None - model_service: model_service.ModelsService = None + llm_model_service: model_service.LLMModelsService = None + + embedding_models_service: model_service.EmbeddingModelsService = None pipeline_service: pipeline_service.PipelineService = None diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 6ee35610..482a468b 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -95,8 +95,11 @@ class BuildAppStage(stage.BootingStage): user_service_inst = user_service.UserService(ap) ap.user_service = user_service_inst - model_service_inst = model_service.ModelsService(ap) - ap.model_service = model_service_inst + llm_model_service_inst = model_service.LLMModelsService(ap) + ap.llm_model_service = llm_model_service_inst + + embedding_models_service_inst = model_service.EmbeddingModelsService(ap) + ap.embedding_models_service = embedding_models_service_inst pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst diff --git a/pkg/entity/persistence/model.py b/pkg/entity/persistence/model.py index 9eb2ccef..418cab70 100644 --- a/pkg/entity/persistence/model.py +++ b/pkg/entity/persistence/model.py @@ -23,3 +23,24 @@ class LLMModel(Base): server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now(), ) + + +class EmbeddingModel(Base): + """Embedding 模型""" + + __tablename__ = 'embedding_models' + + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py index cf856894..7bc02a32 100644 --- a/pkg/provider/modelmgr/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel): token_mgr: token.TokenManager - requester: requester.LLMAPIRequester + requester: requester.ProviderAPIRequester tool_call_supported: typing.Optional[bool] = False diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index b15e53a9..2c92eacc 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -18,7 +18,7 @@ class ModelManager: model_list: list[entities.LLMModelInfo] # deprecated - requesters: dict[str, requester.LLMAPIRequester] # deprecated + requesters: dict[str, requester.ProviderAPIRequester] # deprecated token_mgrs: dict[str, token.TokenManager] # deprecated @@ -28,9 +28,11 @@ class ModelManager: llm_models: list[requester.RuntimeLLMModel] + embedding_models: list[requester.RuntimeEmbeddingModel] + requester_components: list[engine.Component] - requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache + requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache def __init__(self, ap: app.Application): self.ap = ap @@ -38,6 +40,7 @@ class ModelManager: self.requesters = {} self.token_mgrs = {} self.llm_models = [] + self.embedding_models = [] self.requester_components = [] self.requester_dict = {} @@ -45,7 +48,7 @@ class ModelManager: self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') # forge requester class dict - requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} + requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {} for component in self.requester_components: requester_dict[component.metadata.name] = component.get_python_component_class() @@ -58,13 +61,11 @@ class ModelManager: self.ap.logger.info('Loading models from db...') self.llm_models = [] + self.embedding_models = [] # llm models result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) - llm_models = result.all() - - # load models for llm_model in llm_models: try: await self.load_llm_model(llm_model) @@ -73,11 +74,17 @@ class ModelManager: except Exception as e: self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}') + # embedding models + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + embedding_models = result.all() + for embedding_model in embedding_models: + await self.load_embedding_model(embedding_model) + async def init_runtime_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """初始化运行时模型""" + """初始化运行时 LLM 模型""" if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.LLMModel(**model_info._mapping) elif isinstance(model_info, dict): @@ -101,14 +108,47 @@ class ModelManager: return runtime_llm_model + async def init_runtime_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """初始化运行时 Embedding 模型""" + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.EmbeddingModel(**model_info._mapping) + elif isinstance(model_info, dict): + model_info = persistence_model.EmbeddingModel(**model_info) + + requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) + + await requester_inst.initialize() + + runtime_embedding_model = requester.RuntimeEmbeddingModel( + model_entity=model_info, + token_mgr=token.TokenManager( + name=model_info.uuid, + tokens=model_info.api_keys, + ), + requester=requester_inst, + ) + + return runtime_embedding_model + async def load_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """加载模型""" + """加载 LLM 模型""" runtime_llm_model = await self.init_runtime_llm_model(model_info) self.llm_models.append(runtime_llm_model) + async def load_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """加载 Embedding 模型""" + runtime_embedding_model = await self.init_runtime_embedding_model(model_info) + self.embedding_models.append(runtime_embedding_model) + async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated """通过名称获取模型""" for model in self.model_list: @@ -116,23 +156,44 @@ class ModelManager: return model raise ValueError(f'无法确定模型 {name} 的信息') - async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo: - """通过uuid获取模型""" + async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel: + """通过uuid获取 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == uuid: return model - raise ValueError(f'model {uuid} not found') + raise ValueError(f'LLM model {uuid} not found') + + async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel: + """通过uuid获取 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == uuid: + return model + raise ValueError(f'Embedding model {uuid} not found') async def remove_llm_model(self, model_uuid: str): - """移除模型""" + """移除 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == model_uuid: self.llm_models.remove(model) return - def get_available_requesters_info(self) -> list[dict]: + async def remove_embedding_model(self, model_uuid: str): + """移除 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == model_uuid: + self.embedding_models.remove(model) + return + + def get_available_requesters_info(self, model_type: str) -> list[dict]: """获取所有可用的请求器""" - return [component.to_plain_dict() for component in self.requester_components] + if model_type != '': + return [ + component.to_plain_dict() + for component in self.requester_components + if model_type in component.spec['support_type'] + ] + else: + return [component.to_plain_dict() for component in self.requester_components] def get_available_requester_info_by_name(self, name: str) -> dict | None: """通过名称获取请求器信息""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 244f4c82..9742a52c 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -20,22 +20,45 @@ class RuntimeLLMModel: token_mgr: token.TokenManager """api key管理器""" - requester: LLMAPIRequester + requester: ProviderAPIRequester """请求器实例""" def __init__( self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, - requester: LLMAPIRequester, + requester: ProviderAPIRequester, ): self.model_entity = model_entity self.token_mgr = token_mgr self.requester = requester -class LLMAPIRequester(metaclass=abc.ABCMeta): - """LLM API请求器""" +class RuntimeEmbeddingModel: + """运行时 Embedding 模型""" + + model_entity: persistence_model.EmbeddingModel + """模型数据""" + + token_mgr: token.TokenManager + """api key管理器""" + + requester: ProviderAPIRequester + """请求器实例""" + + def __init__( + self, + model_entity: persistence_model.EmbeddingModel, + token_mgr: token.TokenManager, + requester: ProviderAPIRequester, + ): + self.model_entity = model_entity + self.token_mgr = token_mgr + self.requester = requester + + +class ProviderAPIRequester(metaclass=abc.ABCMeta): + """Provider API请求器""" name: str = None @@ -74,3 +97,23 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): llm_entities.Message: 返回消息对象 """ pass + + async def invoke_embedding( + self, + query: core_entities.Query, + model: RuntimeEmbeddingModel, + input_text: str, + extra_args: dict[str, typing.Any] = {}, + ) -> list[float]: + """调用 Embedding API + + Args: + query (core_entities.Query): 请求上下文 + model (RuntimeEmbeddingModel): 使用的模型信息 + input_text (str): 输入文本 + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + + Returns: + list[float]: 返回的 embedding 向量 + """ + pass diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 38573854..b195ae51 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -15,7 +15,7 @@ from ...tools import entities as tools_entities from ....utils import image -class AnthropicMessages(requester.LLMAPIRequester): +class AnthropicMessages(requester.ProviderAPIRequester): """Anthropic Messages API 请求器""" client: anthropic.AsyncAnthropic diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index c124fed9..7dbcf3ed 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./anthropicmsgs.py diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index 24beb915..10aae30f 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./bailianchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 513086e5..98d1f13a 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -13,7 +13,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class OpenAIChatCompletions(requester.LLMAPIRequester): +class OpenAIChatCompletions(requester.ProviderAPIRequester): """OpenAI ChatCompletion API 请求器""" client: openai.AsyncClient @@ -141,3 +141,39 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def invoke_embedding( + self, + query: core_entities.Query, + model: requester.RuntimeEmbeddingModel, + input_text: str, + extra_args: dict[str, typing.Any] = {}, + ) -> list[float]: + """调用 Embedding API""" + self.client.api_key = model.token_mgr.get_token() + + args = { + 'model': model.model_entity.name, + 'input': input_text, + } + + if model.model_entity.extra_args: + args.update(model.model_entity.extra_args) + + args.update(extra_args) + + try: + resp = await self.client.embeddings.create(**args) + return resp.data[0].embedding + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + raise errors.RequesterError(f'请求参数错误: {e.message}') + except openai.AuthenticationError as e: + raise errors.RequesterError(f'无效的 api-key: {e.message}') + except openai.NotFoundError as e: + raise errors.RequesterError(f'请求路径错误: {e.message}') + except openai.RateLimitError as e: + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') + except openai.APIError as e: + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/pkg/provider/modelmgr/requesters/chatcmpl.yaml index 908b30ac..ff0de6f9 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -22,6 +22,9 @@ spec: type: integer required: true default: 120 + support_type: + - llm + - text-embedding execution: python: path: ./chatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml index ea2c7eea..6f320e66 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./deepseekchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml index 6bfc085e..73fca19c 100644 --- a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./geminichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index a18675a1..3a79bb49 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./giteeaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml index 893235b2..fbe57dad 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./lmstudiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index b8868f4d..4708f671 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -14,7 +14,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class ModelScopeChatCompletions(requester.LLMAPIRequester): +class ModelScopeChatCompletions(requester.ProviderAPIRequester): """ModelScope ChatCompletion API 请求器""" client: openai.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml index a641a672..a926d889 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./modelscopechatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml index f3ae73c8..52f7bcda 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./moonshotchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 2ea4bb7d..1456515f 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -17,7 +17,7 @@ from ....core import entities as core_entities REQUESTER_NAME: str = 'ollama-chat' -class OllamaChatCompletions(requester.LLMAPIRequester): +class OllamaChatCompletions(requester.ProviderAPIRequester): """Ollama平台 ChatCompletion API请求器""" client: ollama.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/ollamachat.yaml b/pkg/provider/modelmgr/requesters/ollamachat.yaml index 01435775..f4c4bf5a 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.yaml +++ b/pkg/provider/modelmgr/requesters/ollamachat.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./ollamachat.py diff --git a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml index 2ecee6cc..ea35bce6 100644 --- a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./openrouterchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml index 9f201aa9..a5a3421c 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./ppiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 19b3dcc3..3872cb6f 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./siliconflowchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml index 402f04e7..c711ef2d 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./volcarkchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml index 29db4eb3..2769a402 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./xaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml index a05184ef..34539d95 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./zhipuaichatcmpl.py diff --git a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx index e21317d6..ef9c6f45 100644 --- a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx +++ b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx @@ -47,6 +47,7 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/deploy/models/readme.html', }, }), + new SidebarChildVO({ id: 'pipelines', name: t('pipelines.title'), diff --git a/web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts b/web/src/app/home/models/component/ChooseRequesterEntity.ts similarity index 100% rename from web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts rename to web/src/app/home/models/component/ChooseRequesterEntity.ts diff --git a/web/src/app/home/models/component/ICreateEmbeddingField.ts b/web/src/app/home/models/component/ICreateEmbeddingField.ts new file mode 100644 index 00000000..ea198f3f --- /dev/null +++ b/web/src/app/home/models/component/ICreateEmbeddingField.ts @@ -0,0 +1,7 @@ +export interface ICreateEmbeddingField { + name: string; + model_provider: string; + url: string; + api_key: string; + extra_args?: string[]; +} diff --git a/web/src/app/home/models/ICreateLLMField.ts b/web/src/app/home/models/component/ICreateLLMField.ts similarity index 100% rename from web/src/app/home/models/ICreateLLMField.ts rename to web/src/app/home/models/component/ICreateLLMField.ts diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css new file mode 100644 index 00000000..9c6c54f7 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css @@ -0,0 +1,97 @@ +.cardContainer { + width: 100%; + height: 10rem; + background-color: #fff; + border-radius: 10px; + box-shadow: 0px 2px 2px 0 rgba(0, 0, 0, 0.2); + padding: 1.2rem; + cursor: pointer; +} + +.cardContainer:hover { + box-shadow: 0px 2px 8px 0 rgba(0, 0, 0, 0.1); +} + +.iconBasicInfoContainer { + width: 100%; + height: 100%; + display: flex; + flex-direction: row; + gap: 0.8rem; + user-select: none; +} + +.iconImage { + width: 3.8rem; + height: 3.8rem; + margin: 0.2rem; + border-radius: 50%; +} + +.basicInfoContainer { + display: flex; + flex-direction: column; + gap: 0.2rem; + min-width: 0; + width: 100%; +} + +.basicInfoText { + font-size: 1.4rem; + font-weight: bold; +} + +.providerContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; +} + +.providerIcon { + width: 1.2rem; + height: 1.2rem; + margin-top: 0.2rem; + color: #626262; +} + +.providerLabel { + font-size: 1.2rem; + font-weight: 600; + color: #626262; +} + +.baseURLContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; + width: calc(100% - 3rem); +} + +.baseURLIcon { + width: 1.2rem; + height: 1.2rem; + color: #626262; +} + +.baseURLText { + font-size: 1rem; + width: 100%; + color: #626262; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + max-width: 100%; +} + +.bigText { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + font-size: 1.4rem; + font-weight: bold; + max-width: 100%; +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx new file mode 100644 index 00000000..e3dfaf80 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx @@ -0,0 +1,53 @@ +import styles from './EmbeddingCard.module.css'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; + +export default function EmbeddingCard({ cardVO }: { cardVO: EmbeddingCardVO }) { + return ( +
+
+ icon + +
+ {/* 名称 */} +
+ {cardVO.name} +
+ {/* 厂商 */} +
+ + + + + {cardVO.providerLabel} + +
+ {/* baseURL */} +
+ + + + {cardVO.baseURL} +
+
+
+
+ ); +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts new file mode 100644 index 00000000..f6d960f6 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts @@ -0,0 +1,23 @@ +export interface IEmbeddingCardVO { + id: string; + iconURL: string; + name: string; + providerLabel: string; + baseURL: string; +} + +export class EmbeddingCardVO implements IEmbeddingCardVO { + id: string; + iconURL: string; + providerLabel: string; + name: string; + baseURL: string; + + constructor(props: IEmbeddingCardVO) { + this.id = props.id; + this.iconURL = props.iconURL; + this.providerLabel = props.providerLabel; + this.name = props.name; + this.baseURL = props.baseURL; + } +} diff --git a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx new file mode 100644 index 00000000..4658a22f --- /dev/null +++ b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx @@ -0,0 +1,563 @@ +import { ICreateEmbeddingField } from '@/app/home/models/component/ICreateEmbeddingField'; +import { useEffect, useState } from 'react'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { EmbeddingModel } from '@/app/infra/entities/api'; +import { UUID } from 'uuidjs'; + +import { zodResolver } from '@hookform/resolvers/zod'; +import { useForm } from 'react-hook-form'; +import { z } from 'zod'; +import { useTranslation } from 'react-i18next'; + +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogFooter, +} from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { toast } from 'sonner'; +import { i18nObj } from '@/i18n/I18nProvider'; + +const getExtraArgSchema = (t: (key: string) => string) => + z + .object({ + key: z.string().min(1, { message: t('models.keyNameRequired') }), + type: z.enum(['string', 'number', 'boolean']), + value: z.string(), + }) + .superRefine((data, ctx) => { + if (data.type === 'number' && isNaN(Number(data.value))) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('models.mustBeValidNumber'), + path: ['value'], + }); + } + if ( + data.type === 'boolean' && + data.value !== 'true' && + data.value !== 'false' + ) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: t('models.mustBeTrueOrFalse'), + path: ['value'], + }); + } + }); + +const getFormSchema = (t: (key: string) => string) => + z.object({ + name: z.string().min(1, { message: t('models.modelNameRequired') }), + model_provider: z + .string() + .min(1, { message: t('models.modelProviderRequired') }), + url: z.string().min(1, { message: t('models.requestURLRequired') }), + api_key: z.string().min(1, { message: t('models.apiKeyRequired') }), + extra_args: z.array(getExtraArgSchema(t)).optional(), + }); + +export default function EmbeddingForm({ + editMode, + initEmbeddingId, + onFormSubmit, + onFormCancel, + onEmbeddingDeleted, +}: { + editMode: boolean; + initEmbeddingId?: string; + onFormSubmit: () => void; + onFormCancel: () => void; + onEmbeddingDeleted: () => void; +}) { + const { t } = useTranslation(); + const formSchema = getFormSchema(t); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + name: '', + model_provider: '', + url: '', + api_key: 'sk-xxxxx', + extra_args: [], + }, + }); + + const [extraArgs, setExtraArgs] = useState< + { key: string; type: 'string' | 'number' | 'boolean'; value: string }[] + >([]); + + const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false); + const [requesterNameList, setRequesterNameList] = useState< + IChooseRequesterEntity[] + >([]); + const [requesterDefaultURLList, setRequesterDefaultURLList] = useState< + string[] + >([]); + const [modelTesting, setModelTesting] = useState(false); + + useEffect(() => { + initEmbeddingModelFormComponent().then(() => { + if (editMode && initEmbeddingId) { + getEmbeddingConfig(initEmbeddingId).then((val) => { + form.setValue('name', val.name); + form.setValue('model_provider', val.model_provider); + // setCurrentModelProvider(val.model_provider); + form.setValue('url', val.url); + form.setValue('api_key', val.api_key); + if (val.extra_args) { + const args = val.extra_args.map((arg) => { + const [key, value] = arg.split(':'); + let type: 'string' | 'number' | 'boolean' = 'string'; + if (!isNaN(Number(value))) { + type = 'number'; + } else if (value === 'true' || value === 'false') { + type = 'boolean'; + } + return { + key, + type, + value, + }; + }); + setExtraArgs(args); + form.setValue('extra_args', args); + } + }); + } else { + form.reset(); + } + }); + }, []); + + const addExtraArg = () => { + setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]); + }; + + const updateExtraArg = ( + index: number, + field: 'key' | 'type' | 'value', + value: string, + ) => { + const newArgs = [...extraArgs]; + newArgs[index] = { + ...newArgs[index], + [field]: value, + }; + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); + }; + + const removeExtraArg = (index: number) => { + const newArgs = extraArgs.filter((_, i) => i !== index); + setExtraArgs(newArgs); + form.setValue('extra_args', newArgs); + }; + + async function initEmbeddingModelFormComponent() { + const requesterNameList = + await httpClient.getProviderRequesters('text-embedding'); + setRequesterNameList( + requesterNameList.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }), + ); + setRequesterDefaultURLList( + requesterNameList.requesters.map((item) => { + const config = item.spec.config; + for (let i = 0; i < config.length; i++) { + if (config[i].name == 'base_url') { + return config[i].default?.toString() || ''; + } + } + return ''; + }), + ); + } + + async function getEmbeddingConfig( + id: string, + ): Promise { + const embeddingModel = await httpClient.getProviderEmbeddingModel(id); + + const fakeExtraArgs = []; + const extraArgs = embeddingModel.model.extra_args as Record; + for (const key in extraArgs) { + fakeExtraArgs.push(`${key}:${extraArgs[key]}`); + } + return { + name: embeddingModel.model.name, + model_provider: embeddingModel.model.requester, + url: embeddingModel.model.requester_config?.base_url, + api_key: embeddingModel.model.api_keys[0], + extra_args: fakeExtraArgs, + }; + } + + function handleFormSubmit(value: z.infer) { + const extraArgsObj: Record = {}; + value.extra_args?.forEach( + (arg: { key: string; type: string; value: string }) => { + if (arg.type === 'number') { + extraArgsObj[arg.key] = Number(arg.value); + } else if (arg.type === 'boolean') { + extraArgsObj[arg.key] = arg.value === 'true'; + } else { + extraArgsObj[arg.key] = arg.value; + } + }, + ); + + const embeddingModel: EmbeddingModel = { + uuid: editMode ? initEmbeddingId || '' : UUID.generate(), + name: value.name, + description: '', + requester: value.model_provider, + requester_config: { + base_url: value.url, + timeout: 120, + }, + extra_args: extraArgsObj, + api_keys: [value.api_key], + }; + + if (editMode) { + onSaveEdit(embeddingModel).then(() => { + form.reset(); + }); + } else { + onCreateEmbedding(embeddingModel).then(() => { + form.reset(); + }); + } + } + + async function onCreateEmbedding(embeddingModel: EmbeddingModel) { + try { + await httpClient.createProviderEmbeddingModel(embeddingModel); + onFormSubmit(); + toast.success(t('models.createSuccess')); + } catch (err) { + toast.error(t('models.createError') + (err as Error).message); + } + } + + async function onSaveEdit(embeddingModel: EmbeddingModel) { + try { + await httpClient.updateProviderEmbeddingModel( + initEmbeddingId || '', + embeddingModel, + ); + onFormSubmit(); + toast.success(t('models.saveSuccess')); + } catch (err) { + toast.error(t('models.saveError') + (err as Error).message); + } + } + + function deleteModel() { + if (initEmbeddingId) { + httpClient + .deleteProviderEmbeddingModel(initEmbeddingId) + .then(() => { + onEmbeddingDeleted(); + toast.success(t('models.deleteSuccess')); + }) + .catch((err) => { + toast.error(t('models.deleteError') + err.message); + }); + } + } + + function testEmbeddingModelInForm() { + setModelTesting(true); + httpClient + .testEmbeddingModel('_', { + uuid: '', + name: form.getValues('name'), + description: '', + requester: form.getValues('model_provider'), + requester_config: { + base_url: form.getValues('url'), + timeout: 120, + }, + api_keys: [form.getValues('api_key')], + }) + .then((res) => { + console.log(res); + toast.success(t('models.testSuccess')); + }) + .catch(() => { + toast.error(t('models.testError')); + }) + .finally(() => { + setModelTesting(false); + }); + } + + return ( +
+ + + + {t('common.confirmDelete')} + + + {t('models.deleteConfirmation')} + + + + + + + + +
+ +
+ ( + + + {t('models.modelName')} + * + + + + + + + {t('models.modelProviderDescription')} + + + )} + /> + + ( + + + {t('models.modelProvider')} + * + + + + + + + )} + /> + + ( + + + {t('models.requestURL')} + * + + + + + + + )} + /> + + ( + + + {t('models.apiKey')} + * + + + + + + + )} + /> + + + {t('models.extraParameters')} +
+ {extraArgs.map((arg, index) => ( +
+ + updateExtraArg(index, 'key', e.target.value) + } + /> + + + updateExtraArg(index, 'value', e.target.value) + } + /> + +
+ ))} + +
+ + {t('embedding.extraParametersDescription')} + + +
+
+ + {editMode && ( + + )} + + + + + + + +
+ +
+ ); +} diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index f483f183..73cc32fe 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -1,6 +1,6 @@ -import { ICreateLLMField } from '@/app/home/models/ICreateLLMField'; +import { ICreateLLMField } from '@/app/home/models/component/ICreateLLMField'; import { useEffect, useState } from 'react'; -import { IChooseRequesterEntity } from '@/app/home/models/component/llm-form/ChooseRequesterEntity'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; import { UUID } from 'uuidjs'; @@ -197,7 +197,7 @@ export default function LLMForm({ }; async function initLLMModelFormComponent() { - const requesterNameList = await httpClient.getProviderRequesters(); + const requesterNameList = await httpClient.getProviderRequesters('llm'); setRequesterNameList( requesterNameList.requesters.map((item) => { return { @@ -596,7 +596,7 @@ export default function LLMForm({ - {t('models.extraParametersDescription')} + {t('llm.extraParametersDescription')} diff --git a/web/src/app/home/models/page.tsx b/web/src/app/home/models/page.tsx index 3ccec486..2f936753 100644 --- a/web/src/app/home/models/page.tsx +++ b/web/src/app/home/models/page.tsx @@ -8,6 +8,7 @@ import LLMForm from '@/app/home/models/component/llm-form/LLMForm'; import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; import { Dialog, DialogContent, @@ -17,6 +18,9 @@ import { import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; import { i18nObj } from '@/i18n/I18nProvider'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; +import EmbeddingCard from '@/app/home/models/component/embedding-card/EmbeddingCard'; +import EmbeddingForm from '@/app/home/models/component/embedding-form/EmbeddingForm'; export default function LLMConfigPage() { const { t } = useTranslation(); @@ -24,13 +28,21 @@ export default function LLMConfigPage() { const [modalOpen, setModalOpen] = useState(false); const [isEditForm, setIsEditForm] = useState(false); const [nowSelectedLLM, setNowSelectedLLM] = useState(null); + const [embeddingCardList, setEmbeddingCardList] = useState( + [], + ); + const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false); + const [isEditEmbeddingForm, setIsEditEmbeddingForm] = useState(false); + const [nowSelectedEmbedding, setNowSelectedEmbedding] = + useState(null); useEffect(() => { getLLMModelList(); + getEmbeddingModelList(); }, []); async function getLLMModelList() { - const requesterNameListResp = await httpClient.getProviderRequesters(); + const requesterNameListResp = await httpClient.getProviderRequesters('llm'); const requesterNameList = requesterNameListResp.requesters.map((item) => { return { label: i18nObj(item.label), @@ -74,6 +86,55 @@ export default function LLMConfigPage() { setNowSelectedLLM(null); setModalOpen(true); } + function selectEmbedding(cardVO: EmbeddingCardVO) { + setIsEditEmbeddingForm(true); + setNowSelectedEmbedding(cardVO); + setEmbeddingModalOpen(true); + } + + function handleCreateEmbeddingModelClick() { + setIsEditEmbeddingForm(false); + setNowSelectedEmbedding(null); + setEmbeddingModalOpen(true); + } + async function getEmbeddingModelList() { + const requesterNameListResp = + await httpClient.getProviderRequesters('text-embedding'); + const requesterNameList = requesterNameListResp.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }); + + httpClient + .getProviderEmbeddingModels() + .then((resp) => { + const embeddingModelList: EmbeddingCardVO[] = resp.models.map( + (model: { + uuid: string; + requester: string; + name: string; + requester_config?: { base_url?: string }; + }) => { + return new EmbeddingCardVO({ + id: model.uuid, + iconURL: httpClient.getProviderRequesterIconURL(model.requester), + name: model.name, + providerLabel: + requesterNameList.find((item) => item.value === model.requester) + ?.label || model.requester.substring(0, 10), + baseURL: model.requester_config?.base_url || '', + }); + }, + ); + setEmbeddingCardList(embeddingModelList); + }) + .catch((err) => { + console.error('get Embedding model list error', err); + toast.error(t('embedding.getModelListError') + err.message); + }); + } return (
@@ -101,26 +162,108 @@ export default function LLMConfigPage() { /> -
- - {cardList.map((cardVO) => { - return ( -
{ - selectLLM(cardVO); - }} - > - + + + + + {isEditEmbeddingForm + ? t('embedding.editModel') + : t('embedding.createModel')} + + + { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + onFormCancel={() => { + setEmbeddingModalOpen(false); + }} + onEmbeddingDeleted={() => { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + /> + + + + +
+
+ + + {t('llm.llmModels')} + + + {t('embedding.embeddingModels')} + + +
+ +
+

{t('llm.description')}

- ); - })} -
+ + +
+

+ {t('embedding.description')} +

+
+
+
+ + +
+ + {cardList.map((cardVO) => { + return ( +
{ + selectLLM(cardVO); + }} + > + +
+ ); + })} +
+
+ + +
+ + {embeddingCardList.map((cardVO) => { + return ( +
{ + selectEmbedding(cardVO); + }} + > + +
+ ); + })} +
+
+
); } diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index d86a8be0..53ddf1dd 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -55,6 +55,29 @@ export interface LLMModel { // updated_at: string; } +export interface ApiRespProviderEmbeddingModels { + models: EmbeddingModel[]; +} + +export interface ApiRespProviderEmbeddingModel { + model: EmbeddingModel; +} + +export interface EmbeddingModel { + name: string; + description: string; + uuid: string; + requester: string; + requester_config: { + base_url: string; + timeout: number; + }; + extra_args?: object; + api_keys: string[]; + // created_at: string; + // updated_at: string; +} + export interface ApiRespPipelines { pipelines: Pipeline[]; } diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 5193703b..1fd335d9 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -10,6 +10,9 @@ import { ApiRespProviderLLMModels, ApiRespProviderLLMModel, LLMModel, + ApiRespProviderEmbeddingModels, + ApiRespProviderEmbeddingModel, + EmbeddingModel, ApiRespPipelines, Pipeline, ApiRespPlatformAdapters, @@ -226,8 +229,10 @@ class HttpClient { // real api request implementation // ============ Provider API ============ - public getProviderRequesters(): Promise { - return this.get('/api/v1/provider/requesters'); + public getProviderRequesters( + model_type: string, + ): Promise { + return this.get('/api/v1/provider/requesters', { type: model_type }); } public getProviderRequester(name: string): Promise { @@ -275,6 +280,39 @@ class HttpClient { return this.post(`/api/v1/provider/models/llm/${uuid}/test`, model); } + // ============ Provider Model Embedding ============ + public getProviderEmbeddingModels(): Promise { + return this.get('/api/v1/provider/models/embedding'); + } + + public getProviderEmbeddingModel( + uuid: string, + ): Promise { + return this.get(`/api/v1/provider/models/embedding/${uuid}`); + } + + public createProviderEmbeddingModel(model: EmbeddingModel): Promise { + return this.post('/api/v1/provider/models/embedding', model); + } + + public deleteProviderEmbeddingModel(uuid: string): Promise { + return this.delete(`/api/v1/provider/models/embedding/${uuid}`); + } + + public updateProviderEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.put(`/api/v1/provider/models/embedding/${uuid}`, model); + } + + public testEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model); + } + // ============ Pipeline API ============ public getGeneralPipelineMetadata(): Promise { // as designed, this method will be deprecated, and only for developer to check the prefered config schema diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 1975a521..d0df9841 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -86,14 +86,13 @@ const enUS = { string: 'String', number: 'Number', boolean: 'Boolean', - extraParametersDescription: - 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', selectModelProvider: 'Select Model Provider', modelProviderDescription: 'Please fill in the model name provided by the supplier', selectModel: 'Select Model', testSuccess: 'Test successful', testError: 'Test failed, please check your model configuration', + llmModels: 'LLM Models', }, bots: { title: 'Bots', @@ -259,6 +258,21 @@ const enUS = { 'Password reset failed, please check your email and recovery key', backToLogin: 'Back to Login', }, + embedding: { + description: 'Manage Embedding models for text vectorization', + createModel: 'Create Embedding Model', + editModel: 'Edit Embedding Model', + getModelListError: 'Failed to get Embedding model list: ', + embeddingModels: 'Embedding', + extraParametersDescription: + 'Will be attached to the request body, such as encoding_format, dimensions, etc.', + }, + llm: { + description: 'Manage LLM models for conversation generation', + llmModels: 'LLM', + extraParametersDescription: + 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', + }, }; export default enUS; diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 2ded8236..96acc0e6 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -87,13 +87,12 @@ const zhHans = { string: '字符串', number: '数字', boolean: '布尔值', - extraParametersDescription: - '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', selectModelProvider: '选择模型供应商', modelProviderDescription: '请填写供应商向您提供的模型名称', selectModel: '请选择模型', testSuccess: '测试成功', testError: '测试失败,请检查模型配置', + llmModels: '对话模型', }, bots: { title: '机器人', @@ -251,6 +250,21 @@ const zhHans = { resetFailed: '密码重置失败,请检查邮箱和恢复密钥是否正确', backToLogin: '返回登录', }, + embedding: { + description: '管理嵌入模型,用于向量化文本', + createModel: '创建嵌入模型', + editModel: '编辑嵌入模型', + getModelListError: '获取嵌入模型列表失败:', + embeddingModels: '嵌入模型', + extraParametersDescription: + '将在请求时附加到请求体中,如 encoding_format, dimensions 等', + }, + llm: { + llmModels: '对话模型', + description: '管理 LLM 模型,用于对话消息生成', + extraParametersDescription: + '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', + }, }; export default zhHans;