feat: add embeddings model management (#1461)

* feat: add embeddings model management backend support

Co-Authored-By: Junyan Qin <Chin> <rockchinq@gmail.com>

* feat: add embeddings model management frontend support

Co-Authored-By: Junyan Qin <Chin> <rockchinq@gmail.com>

* chore: revert HttpClient URL to production setting

Co-Authored-By: Junyan Qin <Chin> <rockchinq@gmail.com>

* refactor: integrate embeddings models into models page with tabs

Co-Authored-By: Junyan Qin <Chin> <rockchinq@gmail.com>

* perf: move files

* perf: remove `s`

* feat: allow requester to declare supported types in manifest

* feat(embedding): delete dimension and encoding format

* feat: add extra_args for embedding moels

* perf: i18n ref

* fix: linter err

* fix: lint err

* fix: linter err

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Junyan Qin <Chin> <rockchinq@gmail.com>
This commit is contained in:
devin-ai-integration[bot]
2025-05-21 12:42:39 +08:00
committed by Junyan Qin
parent a01706d163
commit d2b93b3296
43 changed files with 1370 additions and 64 deletions
+1 -1
View File
@@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel):
token_mgr: token.TokenManager
requester: requester.LLMAPIRequester
requester: requester.ProviderAPIRequester
tool_call_supported: typing.Optional[bool] = False
+75 -14
View File
@@ -18,7 +18,7 @@ class ModelManager:
model_list: list[entities.LLMModelInfo] # deprecated
requesters: dict[str, requester.LLMAPIRequester] # deprecated
requesters: dict[str, requester.ProviderAPIRequester] # deprecated
token_mgrs: dict[str, token.TokenManager] # deprecated
@@ -28,9 +28,11 @@ class ModelManager:
llm_models: list[requester.RuntimeLLMModel]
embedding_models: list[requester.RuntimeEmbeddingModel]
requester_components: list[engine.Component]
requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache
requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache
def __init__(self, ap: app.Application):
self.ap = ap
@@ -38,6 +40,7 @@ class ModelManager:
self.requesters = {}
self.token_mgrs = {}
self.llm_models = []
self.embedding_models = []
self.requester_components = []
self.requester_dict = {}
@@ -45,7 +48,7 @@ class ModelManager:
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
# forge requester class dict
requester_dict: dict[str, type[requester.LLMAPIRequester]] = {}
requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {}
for component in self.requester_components:
requester_dict[component.metadata.name] = component.get_python_component_class()
@@ -58,13 +61,11 @@ class ModelManager:
self.ap.logger.info('Loading models from db...')
self.llm_models = []
self.embedding_models = []
# llm models
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
llm_models = result.all()
# load models
for llm_model in llm_models:
try:
await self.load_llm_model(llm_model)
@@ -73,11 +74,17 @@ class ModelManager:
except Exception as e:
self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}')
# embedding models
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel))
embedding_models = result.all()
for embedding_model in embedding_models:
await self.load_embedding_model(embedding_model)
async def init_runtime_llm_model(
self,
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
):
"""初始化运行时模型"""
"""初始化运行时 LLM 模型"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.LLMModel(**model_info._mapping)
elif isinstance(model_info, dict):
@@ -101,14 +108,47 @@ class ModelManager:
return runtime_llm_model
async def init_runtime_embedding_model(
self,
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict,
):
"""初始化运行时 Embedding 模型"""
if isinstance(model_info, sqlalchemy.Row):
model_info = persistence_model.EmbeddingModel(**model_info._mapping)
elif isinstance(model_info, dict):
model_info = persistence_model.EmbeddingModel(**model_info)
requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config)
await requester_inst.initialize()
runtime_embedding_model = requester.RuntimeEmbeddingModel(
model_entity=model_info,
token_mgr=token.TokenManager(
name=model_info.uuid,
tokens=model_info.api_keys,
),
requester=requester_inst,
)
return runtime_embedding_model
async def load_llm_model(
self,
model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict,
):
"""加载模型"""
"""加载 LLM 模型"""
runtime_llm_model = await self.init_runtime_llm_model(model_info)
self.llm_models.append(runtime_llm_model)
async def load_embedding_model(
self,
model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict,
):
"""加载 Embedding 模型"""
runtime_embedding_model = await self.init_runtime_embedding_model(model_info)
self.embedding_models.append(runtime_embedding_model)
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated
"""通过名称获取模型"""
for model in self.model_list:
@@ -116,23 +156,44 @@ class ModelManager:
return model
raise ValueError(f'无法确定模型 {name} 的信息')
async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo:
"""通过uuid获取模型"""
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
"""通过uuid获取 LLM 模型"""
for model in self.llm_models:
if model.model_entity.uuid == uuid:
return model
raise ValueError(f'model {uuid} not found')
raise ValueError(f'LLM model {uuid} not found')
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel:
"""通过uuid获取 Embedding 模型"""
for model in self.embedding_models:
if model.model_entity.uuid == uuid:
return model
raise ValueError(f'Embedding model {uuid} not found')
async def remove_llm_model(self, model_uuid: str):
"""移除模型"""
"""移除 LLM 模型"""
for model in self.llm_models:
if model.model_entity.uuid == model_uuid:
self.llm_models.remove(model)
return
def get_available_requesters_info(self) -> list[dict]:
async def remove_embedding_model(self, model_uuid: str):
"""移除 Embedding 模型"""
for model in self.embedding_models:
if model.model_entity.uuid == model_uuid:
self.embedding_models.remove(model)
return
def get_available_requesters_info(self, model_type: str) -> list[dict]:
"""获取所有可用的请求器"""
return [component.to_plain_dict() for component in self.requester_components]
if model_type != '':
return [
component.to_plain_dict()
for component in self.requester_components
if model_type in component.spec['support_type']
]
else:
return [component.to_plain_dict() for component in self.requester_components]
def get_available_requester_info_by_name(self, name: str) -> dict | None:
"""通过名称获取请求器信息"""
+47 -4
View File
@@ -20,22 +20,45 @@ class RuntimeLLMModel:
token_mgr: token.TokenManager
"""api key管理器"""
requester: LLMAPIRequester
requester: ProviderAPIRequester
"""请求器实例"""
def __init__(
self,
model_entity: persistence_model.LLMModel,
token_mgr: token.TokenManager,
requester: LLMAPIRequester,
requester: ProviderAPIRequester,
):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器"""
class RuntimeEmbeddingModel:
"""运行时 Embedding 模型"""
model_entity: persistence_model.EmbeddingModel
"""模型数据"""
token_mgr: token.TokenManager
"""api key管理器"""
requester: ProviderAPIRequester
"""请求器实例"""
def __init__(
self,
model_entity: persistence_model.EmbeddingModel,
token_mgr: token.TokenManager,
requester: ProviderAPIRequester,
):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class ProviderAPIRequester(metaclass=abc.ABCMeta):
"""Provider API请求器"""
name: str = None
@@ -74,3 +97,23 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
llm_entities.Message: 返回消息对象
"""
pass
async def invoke_embedding(
self,
query: core_entities.Query,
model: RuntimeEmbeddingModel,
input_text: str,
extra_args: dict[str, typing.Any] = {},
) -> list[float]:
"""调用 Embedding API
Args:
query (core_entities.Query): 请求上下文
model (RuntimeEmbeddingModel): 使用的模型信息
input_text (str): 输入文本
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
Returns:
list[float]: 返回的 embedding 向量
"""
pass
@@ -15,7 +15,7 @@ from ...tools import entities as tools_entities
from ....utils import image
class AnthropicMessages(requester.LLMAPIRequester):
class AnthropicMessages(requester.ProviderAPIRequester):
"""Anthropic Messages API 请求器"""
client: anthropic.AsyncAnthropic
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./anthropicmsgs.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./bailianchatcmpl.py
+37 -1
View File
@@ -13,7 +13,7 @@ from ... import entities as llm_entities
from ...tools import entities as tools_entities
class OpenAIChatCompletions(requester.LLMAPIRequester):
class OpenAIChatCompletions(requester.ProviderAPIRequester):
"""OpenAI ChatCompletion API 请求器"""
client: openai.AsyncClient
@@ -141,3 +141,39 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')
async def invoke_embedding(
self,
query: core_entities.Query,
model: requester.RuntimeEmbeddingModel,
input_text: str,
extra_args: dict[str, typing.Any] = {},
) -> list[float]:
"""调用 Embedding API"""
self.client.api_key = model.token_mgr.get_token()
args = {
'model': model.model_entity.name,
'input': input_text,
}
if model.model_entity.extra_args:
args.update(model.model_entity.extra_args)
args.update(extra_args)
try:
resp = await self.client.embeddings.create(**args)
return resp.data[0].embedding
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
raise errors.RequesterError(f'请求参数错误: {e.message}')
except openai.AuthenticationError as e:
raise errors.RequesterError(f'无效的 api-key: {e.message}')
except openai.NotFoundError as e:
raise errors.RequesterError(f'请求路径错误: {e.message}')
except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')
@@ -22,6 +22,9 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
- text-embedding
execution:
python:
path: ./chatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./deepseekchatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./geminichatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./giteeaichatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./lmstudiochatcmpl.py
@@ -14,7 +14,7 @@ from ... import entities as llm_entities
from ...tools import entities as tools_entities
class ModelScopeChatCompletions(requester.LLMAPIRequester):
class ModelScopeChatCompletions(requester.ProviderAPIRequester):
"""ModelScope ChatCompletion API 请求器"""
client: openai.AsyncClient
@@ -29,6 +29,8 @@ spec:
type: int
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./modelscopechatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./moonshotchatcmpl.py
@@ -17,7 +17,7 @@ from ....core import entities as core_entities
REQUESTER_NAME: str = 'ollama-chat'
class OllamaChatCompletions(requester.LLMAPIRequester):
class OllamaChatCompletions(requester.ProviderAPIRequester):
"""Ollama平台 ChatCompletion API请求器"""
client: ollama.AsyncClient
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./ollamachat.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./openrouterchatcmpl.py
@@ -29,6 +29,8 @@ spec:
type: int
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./ppiochatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./siliconflowchatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./volcarkchatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./xaichatcmpl.py
@@ -22,6 +22,8 @@ spec:
type: integer
required: true
default: 120
support_type:
- llm
execution:
python:
path: ./zhipuaichatcmpl.py