feat(rag): make embedding and retrieving available

This commit is contained in:
Junyan Qin
2025-07-16 21:17:18 +08:00
parent f731115805
commit 2f2db4d445
20 changed files with 180 additions and 368 deletions

View File

@@ -101,18 +101,18 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
async def invoke_embedding(
self,
model: RuntimeEmbeddingModel,
input_text: str,
input_text: list[str],
extra_args: dict[str, typing.Any] = {},
) -> list[float]:
) -> list[list[float]]:
"""调用 Embedding API
Args:
query (core_entities.Query): 请求上下文
model (RuntimeEmbeddingModel): 使用的模型信息
input_text (str): 输入文本
input_text (list[str]): 输入文本
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
Returns:
list[float]: 返回的 embedding 向量
list[list[float]]: 返回的 embedding 向量
"""
pass

View File

@@ -145,9 +145,9 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
async def invoke_embedding(
self,
model: requester.RuntimeEmbeddingModel,
input_text: str,
input_text: list[str],
extra_args: dict[str, typing.Any] = {},
) -> list[float]:
) -> list[list[float]]:
"""调用 Embedding API"""
self.client.api_key = model.token_mgr.get_token()
@@ -163,7 +163,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
try:
resp = await self.client.embeddings.create(**args)
return resp.data[0].embedding
return [d.embedding for d in resp.data]
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e: