feat: update requester config logic

This commit is contained in:
Junyan Qin
2025-03-16 23:16:06 +08:00
parent 5c584ee60d
commit 3124cc0fef
12 changed files with 104 additions and 86 deletions
+38 -30
View File
@@ -25,23 +25,20 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
client: openai.AsyncClient
requester_cfg: dict
def __init__(self, ap: app.Application):
self.ap = ap
self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
default_config: dict[str, typing.Any] = {
"base-url": "https://api.openai.com/v1",
"timeout": 120,
}
async def initialize(self):
self.client = openai.AsyncClient(
api_key="",
base_url=self.requester_cfg['base-url'],
timeout=self.requester_cfg['timeout'],
base_url=self.requester_cfg["base-url"],
timeout=self.requester_cfg["timeout"],
http_client=httpx.AsyncClient(
trust_env=True,
timeout=self.requester_cfg['timeout']
)
trust_env=True, timeout=self.requester_cfg["timeout"]
),
)
async def _req(
@@ -57,8 +54,8 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
chatcmpl_message = chat_completion.choices[0].message.dict()
# 确保 role 字段存在且不为 None
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
chatcmpl_message['role'] = 'assistant'
if "role" not in chatcmpl_message or chatcmpl_message["role"] is None:
chatcmpl_message["role"] = "assistant"
message = llm_entities.Message(**chatcmpl_message)
@@ -70,11 +67,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = self.requester_cfg["args"].copy()
args["model"] = (
use_model.name if use_model.model_name is None else use_model.model_name
)
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
@@ -87,12 +87,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
# 检查vision
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
if "content" in msg and isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_base64":
me["image_url"] = {
"url": me["image_base64"]
}
me["image_url"] = {"url": me["image_base64"]}
me["type"] = "image_url"
del me["image_base64"]
@@ -105,13 +103,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
message = await self._make_msg(resp)
return message
async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
for m in messages:
@@ -119,25 +118,34 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
content = msg_dict.get("content")
if isinstance(content, list):
# 检查 content 列表中是否每个部分都是文本
if all(isinstance(part, dict) and part.get("type") == "text" for part in content):
if all(
isinstance(part, dict) and part.get("type") == "text"
for part in content
):
# 将所有文本部分合并为一个字符串
msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict)
try:
return await self._closure(query=query, req_messages=req_messages, use_model=model, use_funcs=funcs)
return await self._closure(
query=query,
req_messages=req_messages,
use_model=model,
use_funcs=funcs,
extra_args=extra_args,
)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
raise errors.RequesterError("请求超时")
except openai.BadRequestError as e:
if 'context_length_exceeded' in e.message:
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
if "context_length_exceeded" in e.message:
raise errors.RequesterError(f"上文过长,请重置会话: {e.message}")
else:
raise errors.RequesterError(f'请求参数错误: {e.message}')
raise errors.RequesterError(f"请求参数错误: {e.message}")
except openai.AuthenticationError as e:
raise errors.RequesterError(f'无效的 api-key: {e.message}')
raise errors.RequesterError(f"无效的 api-key: {e.message}")
except openai.NotFoundError as e:
raise errors.RequesterError(f'请求路径错误: {e.message}')
raise errors.RequesterError(f"请求路径错误: {e.message}")
except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
raise errors.RequesterError(f"请求过于频繁或余额不足: {e.message}")
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')
raise errors.RequesterError(f"请求错误: {e.message}")