From 3124cc0fef889babf2a77a45c697b0c686b2c971 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 16 Mar 2025 23:16:06 +0800 Subject: [PATCH] feat: update requester config logic --- .../modelmgr/requesters/anthropicmsgs.py | 6 ++ .../modelmgr/requesters/bailianchatcmpl.py | 11 ++- pkg/provider/modelmgr/requesters/chatcmpl.py | 68 +++++++++++-------- .../modelmgr/requesters/deepseekchatcmpl.py | 10 ++- .../modelmgr/requesters/giteeaichatcmpl.py | 8 ++- .../modelmgr/requesters/lmstudiochatcmpl.py | 11 ++- .../modelmgr/requesters/moonshotchatcmpl.py | 10 ++- .../modelmgr/requesters/ollamachat.py | 22 +++--- .../requesters/siliconflowchatcmpl.py | 11 ++- .../modelmgr/requesters/volcarkchatcmpl.py | 11 ++- .../modelmgr/requesters/xaichatcmpl.py | 11 ++- .../modelmgr/requesters/zhipuaichatcmpl.py | 11 ++- 12 files changed, 104 insertions(+), 86 deletions(-) diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index b03e536d..937f5107 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -8,6 +8,7 @@ import base64 import anthropic import httpx +from ....core import app from .. import entities, errors, requester from .. import entities, errors @@ -22,6 +23,11 @@ class AnthropicMessages(requester.LLMAPIRequester): client: anthropic.AsyncAnthropic + default_config: dict[str, typing.Any] = { + 'base-url': 'https://api.anthropic.com/v1', + 'timeout': 120, + } + async def initialize(self): httpx_client = anthropic._base_client.AsyncHttpxClientWrapper( diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.py b/pkg/provider/modelmgr/requesters/bailianchatcmpl.py index 8f6b258c..ce718c4c 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing import openai from . import chatcmpl @@ -12,9 +13,7 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient - requester_cfg: dict - - def __init__(self, ap: app.Application): - self.ap = ap - - self.requester_cfg = self.ap.provider_cfg.data['requester']['bailian-chat-completions'] + default_config: dict[str, typing.Any] = { + 'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'timeout': 120, + } diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 7bf83377..7cf255c0 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -25,23 +25,20 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): client: openai.AsyncClient - requester_cfg: dict - - def __init__(self, ap: app.Application): - self.ap = ap - - self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions'] + default_config: dict[str, typing.Any] = { + "base-url": "https://api.openai.com/v1", + "timeout": 120, + } async def initialize(self): self.client = openai.AsyncClient( api_key="", - base_url=self.requester_cfg['base-url'], - timeout=self.requester_cfg['timeout'], + base_url=self.requester_cfg["base-url"], + timeout=self.requester_cfg["timeout"], http_client=httpx.AsyncClient( - trust_env=True, - timeout=self.requester_cfg['timeout'] - ) + trust_env=True, timeout=self.requester_cfg["timeout"] + ), ) async def _req( @@ -57,8 +54,8 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): chatcmpl_message = chat_completion.choices[0].message.dict() # 确保 role 字段存在且不为 None - if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: - chatcmpl_message['role'] = 'assistant' + if "role" not in chatcmpl_message or chatcmpl_message["role"] is None: + chatcmpl_message["role"] = "assistant" message = llm_entities.Message(**chatcmpl_message) @@ -70,11 +67,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): req_messages: list[dict], use_model: entities.LLMModelInfo, use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取 ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + args = self.requester_cfg["args"].copy() + args["model"] = ( + use_model.name if use_model.model_name is None else use_model.model_name + ) if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) @@ -87,12 +87,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): # 检查vision for msg in messages: - if 'content' in msg and isinstance(msg["content"], list): + if "content" in msg and isinstance(msg["content"], list): for me in msg["content"]: if me["type"] == "image_base64": - me["image_url"] = { - "url": me["image_base64"] - } + me["image_url"] = {"url": me["image_base64"]} me["type"] = "image_url" del me["image_base64"] @@ -105,13 +103,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): message = await self._make_msg(resp) return message - + async def call( self, query: core_entities.Query, model: entities.LLMModelInfo, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: @@ -119,25 +118,34 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): content = msg_dict.get("content") if isinstance(content, list): # 检查 content 列表中是否每个部分都是文本 - if all(isinstance(part, dict) and part.get("type") == "text" for part in content): + if all( + isinstance(part, dict) and part.get("type") == "text" + for part in content + ): # 将所有文本部分合并为一个字符串 msg_dict["content"] = "\n".join(part["text"] for part in content) req_messages.append(msg_dict) try: - return await self._closure(query=query, req_messages=req_messages, use_model=model, use_funcs=funcs) + return await self._closure( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + extra_args=extra_args, + ) except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') + raise errors.RequesterError("请求超时") except openai.BadRequestError as e: - if 'context_length_exceeded' in e.message: - raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') + if "context_length_exceeded" in e.message: + raise errors.RequesterError(f"上文过长,请重置会话: {e.message}") else: - raise errors.RequesterError(f'请求参数错误: {e.message}') + raise errors.RequesterError(f"请求参数错误: {e.message}") except openai.AuthenticationError as e: - raise errors.RequesterError(f'无效的 api-key: {e.message}') + raise errors.RequesterError(f"无效的 api-key: {e.message}") except openai.NotFoundError as e: - raise errors.RequesterError(f'请求路径错误: {e.message}') + raise errors.RequesterError(f"请求路径错误: {e.message}") except openai.RateLimitError as e: - raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') + raise errors.RequesterError(f"请求过于频繁或余额不足: {e.message}") except openai.APIError as e: - raise errors.RequesterError(f'请求错误: {e.message}') + raise errors.RequesterError(f"请求错误: {e.message}") diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index f6453a19..e82d0d81 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing + from . import chatcmpl from .. import entities, errors, requester from ....core import entities as core_entities, app @@ -10,9 +12,10 @@ from ...tools import entities as tools_entities class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): """Deepseek ChatCompletion API 请求器""" - def __init__(self, ap: app.Application): - self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions'] - self.ap = ap + default_config: dict[str, typing.Any] = { + 'base-url': 'https://api.deepseek.com', + 'timeout': 120, + } async def _closure( self, @@ -20,6 +23,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): req_messages: list[dict], use_model: entities.LLMModelInfo, use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index 4beb6ba8..dec0b8d1 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -17,9 +17,10 @@ from .. import entities as modelmgr_entities class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): """Gitee AI ChatCompletions API 请求器""" - def __init__(self, ap: app.Application): - self.ap = ap - self.requester_cfg = ap.provider_cfg.data['requester']['gitee-ai-chat-completions'].copy() + default_config: dict[str, typing.Any] = { + 'base-url': 'https://ai.gitee.com/v1', + 'timeout': 120, + } async def _closure( self, @@ -27,6 +28,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): req_messages: list[dict], use_model: entities.LLMModelInfo, use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py index d2a9bcb7..c00be372 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing import openai from . import chatcmpl @@ -12,9 +13,7 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient - requester_cfg: dict - - def __init__(self, ap: app.Application): - self.ap = ap - - self.requester_cfg = self.ap.provider_cfg.data['requester']['lmstudio-chat-completions'] + default_config: dict[str, typing.Any] = { + 'base-url': 'http://127.0.0.1:1234/v1', + 'timeout': 120, + } diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py index 2e94fd04..3cbe8837 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing + from ....core import app from . import chatcmpl @@ -12,9 +14,10 @@ from ...tools import entities as tools_entities class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): """Moonshot ChatCompletion API 请求器""" - def __init__(self, ap: app.Application): - self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions'] - self.ap = ap + default_config: dict[str, typing.Any] = { + 'base-url': 'https://api.moonshot.cn/v1', + 'timeout': 120, + } async def _closure( self, @@ -22,6 +25,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): req_messages: list[dict], use_model: entities.LLMModelInfo, use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 0ac2915f..fa99cfe5 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -23,17 +23,16 @@ REQUESTER_NAME: str = "ollama-chat" class OllamaChatCompletions(requester.LLMAPIRequester): """Ollama平台 ChatCompletion API请求器""" client: ollama.AsyncClient - request_cfg: dict - def __init__(self, ap: app.Application): - super().__init__(ap) - self.ap = ap - self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME] + default_config: dict[str, typing.Any] = { + 'base-url': 'http://127.0.0.1:11434', + 'timeout': 120, + } async def initialize(self): - os.environ['OLLAMA_HOST'] = self.request_cfg['base-url'] + os.environ['OLLAMA_HOST'] = self.requester_cfg['base-url'] self.client = ollama.AsyncClient( - timeout=self.request_cfg['timeout'] + timeout=self.requester_cfg['timeout'] ) async def _req(self, @@ -44,9 +43,9 @@ class OllamaChatCompletions(requester.LLMAPIRequester): ) async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo, - user_funcs: list[tools_entities.LLMFunction] = None) -> ( - llm_entities.Message): - args: Any = self.request_cfg['args'].copy() + user_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message: + args: Any = self.requester_cfg['args'].copy() args["model"] = use_model.name if use_model.model_name is None else use_model.model_name messages: list[dict] = req_messages.copy() @@ -113,6 +112,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester): model: entities.LLMModelInfo, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: req_messages: list = [] for m in messages: @@ -123,6 +123,6 @@ class OllamaChatCompletions(requester.LLMAPIRequester): msg_dict["content"] = "\n".join(part["text"] for part in content) req_messages.append(msg_dict) try: - return await self._closure(query, req_messages, model, funcs) + return await self._closure(query, req_messages, model, funcs, extra_args) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py index c763556f..a990f809 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing import openai from . import chatcmpl @@ -12,9 +13,7 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient - requester_cfg: dict - - def __init__(self, ap: app.Application): - self.ap = ap - - self.requester_cfg = self.ap.provider_cfg.data['requester']['siliconflow-chat-completions'] + default_config: dict[str, typing.Any] = { + 'base-url': 'https://api.siliconflow.cn/v1', + 'timeout': 120, + } diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py index f2a58789..fbf88826 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing import openai from . import chatcmpl @@ -12,9 +13,7 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient - requester_cfg: dict - - def __init__(self, ap: app.Application): - self.ap = ap - - self.requester_cfg = self.ap.provider_cfg.data['requester']['volcark-chat-completions'] + default_config: dict[str, typing.Any] = { + 'base-url': 'https://ark.cn-beijing.volces.com/api/v3', + 'timeout': 120, + } diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.py b/pkg/provider/modelmgr/requesters/xaichatcmpl.py index 217b142f..47c2939a 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing import openai from . import chatcmpl @@ -12,9 +13,7 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient - requester_cfg: dict - - def __init__(self, ap: app.Application): - self.ap = ap - - self.requester_cfg = self.ap.provider_cfg.data['requester']['xai-chat-completions'] + default_config: dict[str, typing.Any] = { + 'base-url': 'https://api.x.ai/v1', + 'timeout': 120, + } diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py index 18edd36d..1e24a5ef 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing import openai from ....core import app @@ -12,9 +13,7 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient - requester_cfg: dict - - def __init__(self, ap: app.Application): - self.ap = ap - - self.requester_cfg = self.ap.provider_cfg.data['requester']['zhipuai-chat-completions'] + default_config: dict[str, typing.Any] = { + 'base-url': 'https://open.bigmodel.cn/api/paas/v4', + 'timeout': 120, + }