feat: update requester config logic

This commit is contained in:
Junyan Qin
2025-03-16 23:16:06 +08:00
parent 5c584ee60d
commit 3124cc0fef
12 changed files with 104 additions and 86 deletions
@@ -8,6 +8,7 @@ import base64
import anthropic import anthropic
import httpx import httpx
from ....core import app
from .. import entities, errors, requester from .. import entities, errors, requester
from .. import entities, errors from .. import entities, errors
@@ -22,6 +23,11 @@ class AnthropicMessages(requester.LLMAPIRequester):
client: anthropic.AsyncAnthropic client: anthropic.AsyncAnthropic
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.anthropic.com/v1',
'timeout': 120,
}
async def initialize(self): async def initialize(self):
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper( httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing
import openai import openai
from . import chatcmpl from . import chatcmpl
@@ -12,9 +13,7 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient client: openai.AsyncClient
requester_cfg: dict default_config: dict[str, typing.Any] = {
'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
def __init__(self, ap: app.Application): 'timeout': 120,
self.ap = ap }
self.requester_cfg = self.ap.provider_cfg.data['requester']['bailian-chat-completions']
+38 -30
View File
@@ -25,23 +25,20 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
client: openai.AsyncClient client: openai.AsyncClient
requester_cfg: dict default_config: dict[str, typing.Any] = {
"base-url": "https://api.openai.com/v1",
def __init__(self, ap: app.Application): "timeout": 120,
self.ap = ap }
self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
async def initialize(self): async def initialize(self):
self.client = openai.AsyncClient( self.client = openai.AsyncClient(
api_key="", api_key="",
base_url=self.requester_cfg['base-url'], base_url=self.requester_cfg["base-url"],
timeout=self.requester_cfg['timeout'], timeout=self.requester_cfg["timeout"],
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
trust_env=True, trust_env=True, timeout=self.requester_cfg["timeout"]
timeout=self.requester_cfg['timeout'] ),
)
) )
async def _req( async def _req(
@@ -57,8 +54,8 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
chatcmpl_message = chat_completion.choices[0].message.dict() chatcmpl_message = chat_completion.choices[0].message.dict()
# 确保 role 字段存在且不为 None # 确保 role 字段存在且不为 None
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: if "role" not in chatcmpl_message or chatcmpl_message["role"] is None:
chatcmpl_message['role'] = 'assistant' chatcmpl_message["role"] = "assistant"
message = llm_entities.Message(**chatcmpl_message) message = llm_entities.Message(**chatcmpl_message)
@@ -70,11 +67,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
req_messages: list[dict], req_messages: list[dict],
use_model: entities.LLMModelInfo, use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取
) -> llm_entities.Message: ) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy() args = self.requester_cfg["args"].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name args["model"] = (
use_model.name if use_model.model_name is None else use_model.model_name
)
if use_funcs: if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
@@ -87,12 +87,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
# 检查vision # 检查vision
for msg in messages: for msg in messages:
if 'content' in msg and isinstance(msg["content"], list): if "content" in msg and isinstance(msg["content"], list):
for me in msg["content"]: for me in msg["content"]:
if me["type"] == "image_base64": if me["type"] == "image_base64":
me["image_url"] = { me["image_url"] = {"url": me["image_base64"]}
"url": me["image_base64"]
}
me["type"] = "image_url" me["type"] = "image_url"
del me["image_base64"] del me["image_base64"]
@@ -105,13 +103,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
message = await self._make_msg(resp) message = await self._make_msg(resp)
return message return message
async def call( async def call(
self, self,
query: core_entities.Query, query: core_entities.Query,
model: entities.LLMModelInfo, model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None, funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message: ) -> llm_entities.Message:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
for m in messages: for m in messages:
@@ -119,25 +118,34 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
content = msg_dict.get("content") content = msg_dict.get("content")
if isinstance(content, list): if isinstance(content, list):
# 检查 content 列表中是否每个部分都是文本 # 检查 content 列表中是否每个部分都是文本
if all(isinstance(part, dict) and part.get("type") == "text" for part in content): if all(
isinstance(part, dict) and part.get("type") == "text"
for part in content
):
# 将所有文本部分合并为一个字符串 # 将所有文本部分合并为一个字符串
msg_dict["content"] = "\n".join(part["text"] for part in content) msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict) req_messages.append(msg_dict)
try: try:
return await self._closure(query=query, req_messages=req_messages, use_model=model, use_funcs=funcs) return await self._closure(
query=query,
req_messages=req_messages,
use_model=model,
use_funcs=funcs,
extra_args=extra_args,
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise errors.RequesterError('请求超时') raise errors.RequesterError("请求超时")
except openai.BadRequestError as e: except openai.BadRequestError as e:
if 'context_length_exceeded' in e.message: if "context_length_exceeded" in e.message:
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') raise errors.RequesterError(f"上文过长,请重置会话: {e.message}")
else: else:
raise errors.RequesterError(f'请求参数错误: {e.message}') raise errors.RequesterError(f"请求参数错误: {e.message}")
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
raise errors.RequesterError(f'无效的 api-key: {e.message}') raise errors.RequesterError(f"无效的 api-key: {e.message}")
except openai.NotFoundError as e: except openai.NotFoundError as e:
raise errors.RequesterError(f'请求路径错误: {e.message}') raise errors.RequesterError(f"请求路径错误: {e.message}")
except openai.RateLimitError as e: except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') raise errors.RequesterError(f"请求过于频繁或余额不足: {e.message}")
except openai.APIError as e: except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}') raise errors.RequesterError(f"请求错误: {e.message}")
@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import typing
from . import chatcmpl from . import chatcmpl
from .. import entities, errors, requester from .. import entities, errors, requester
from ....core import entities as core_entities, app from ....core import entities as core_entities, app
@@ -10,9 +12,10 @@ from ...tools import entities as tools_entities
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Deepseek ChatCompletion API 请求器""" """Deepseek ChatCompletion API 请求器"""
def __init__(self, ap: app.Application): default_config: dict[str, typing.Any] = {
self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions'] 'base-url': 'https://api.deepseek.com',
self.ap = ap 'timeout': 120,
}
async def _closure( async def _closure(
self, self,
@@ -20,6 +23,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
req_messages: list[dict], req_messages: list[dict],
use_model: entities.LLMModelInfo, use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message: ) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
@@ -17,9 +17,10 @@ from .. import entities as modelmgr_entities
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Gitee AI ChatCompletions API 请求器""" """Gitee AI ChatCompletions API 请求器"""
def __init__(self, ap: app.Application): default_config: dict[str, typing.Any] = {
self.ap = ap 'base-url': 'https://ai.gitee.com/v1',
self.requester_cfg = ap.provider_cfg.data['requester']['gitee-ai-chat-completions'].copy() 'timeout': 120,
}
async def _closure( async def _closure(
self, self,
@@ -27,6 +28,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
req_messages: list[dict], req_messages: list[dict],
use_model: entities.LLMModelInfo, use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message: ) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing
import openai import openai
from . import chatcmpl from . import chatcmpl
@@ -12,9 +13,7 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient client: openai.AsyncClient
requester_cfg: dict default_config: dict[str, typing.Any] = {
'base-url': 'http://127.0.0.1:1234/v1',
def __init__(self, ap: app.Application): 'timeout': 120,
self.ap = ap }
self.requester_cfg = self.ap.provider_cfg.data['requester']['lmstudio-chat-completions']
@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import typing
from ....core import app from ....core import app
from . import chatcmpl from . import chatcmpl
@@ -12,9 +14,10 @@ from ...tools import entities as tools_entities
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Moonshot ChatCompletion API 请求器""" """Moonshot ChatCompletion API 请求器"""
def __init__(self, ap: app.Application): default_config: dict[str, typing.Any] = {
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions'] 'base-url': 'https://api.moonshot.cn/v1',
self.ap = ap 'timeout': 120,
}
async def _closure( async def _closure(
self, self,
@@ -22,6 +25,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
req_messages: list[dict], req_messages: list[dict],
use_model: entities.LLMModelInfo, use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message: ) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
+11 -11
View File
@@ -23,17 +23,16 @@ REQUESTER_NAME: str = "ollama-chat"
class OllamaChatCompletions(requester.LLMAPIRequester): class OllamaChatCompletions(requester.LLMAPIRequester):
"""Ollama平台 ChatCompletion API请求器""" """Ollama平台 ChatCompletion API请求器"""
client: ollama.AsyncClient client: ollama.AsyncClient
request_cfg: dict
def __init__(self, ap: app.Application): default_config: dict[str, typing.Any] = {
super().__init__(ap) 'base-url': 'http://127.0.0.1:11434',
self.ap = ap 'timeout': 120,
self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME] }
async def initialize(self): async def initialize(self):
os.environ['OLLAMA_HOST'] = self.request_cfg['base-url'] os.environ['OLLAMA_HOST'] = self.requester_cfg['base-url']
self.client = ollama.AsyncClient( self.client = ollama.AsyncClient(
timeout=self.request_cfg['timeout'] timeout=self.requester_cfg['timeout']
) )
async def _req(self, async def _req(self,
@@ -44,9 +43,9 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
) )
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo, async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
user_funcs: list[tools_entities.LLMFunction] = None) -> ( user_funcs: list[tools_entities.LLMFunction] = None,
llm_entities.Message): extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message:
args: Any = self.request_cfg['args'].copy() args: Any = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
messages: list[dict] = req_messages.copy() messages: list[dict] = req_messages.copy()
@@ -113,6 +112,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
model: entities.LLMModelInfo, model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None, funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message: ) -> llm_entities.Message:
req_messages: list = [] req_messages: list = []
for m in messages: for m in messages:
@@ -123,6 +123,6 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
msg_dict["content"] = "\n".join(part["text"] for part in content) msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict) req_messages.append(msg_dict)
try: try:
return await self._closure(query, req_messages, model, funcs) return await self._closure(query, req_messages, model, funcs, extra_args)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise errors.RequesterError('请求超时') raise errors.RequesterError('请求超时')
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing
import openai import openai
from . import chatcmpl from . import chatcmpl
@@ -12,9 +13,7 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient client: openai.AsyncClient
requester_cfg: dict default_config: dict[str, typing.Any] = {
'base-url': 'https://api.siliconflow.cn/v1',
def __init__(self, ap: app.Application): 'timeout': 120,
self.ap = ap }
self.requester_cfg = self.ap.provider_cfg.data['requester']['siliconflow-chat-completions']
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing
import openai import openai
from . import chatcmpl from . import chatcmpl
@@ -12,9 +13,7 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient client: openai.AsyncClient
requester_cfg: dict default_config: dict[str, typing.Any] = {
'base-url': 'https://ark.cn-beijing.volces.com/api/v3',
def __init__(self, ap: app.Application): 'timeout': 120,
self.ap = ap }
self.requester_cfg = self.ap.provider_cfg.data['requester']['volcark-chat-completions']
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing
import openai import openai
from . import chatcmpl from . import chatcmpl
@@ -12,9 +13,7 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient client: openai.AsyncClient
requester_cfg: dict default_config: dict[str, typing.Any] = {
'base-url': 'https://api.x.ai/v1',
def __init__(self, ap: app.Application): 'timeout': 120,
self.ap = ap }
self.requester_cfg = self.ap.provider_cfg.data['requester']['xai-chat-completions']
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import typing
import openai import openai
from ....core import app from ....core import app
@@ -12,9 +13,7 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient client: openai.AsyncClient
requester_cfg: dict default_config: dict[str, typing.Any] = {
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
def __init__(self, ap: app.Application): 'timeout': 120,
self.ap = ap }
self.requester_cfg = self.ap.provider_cfg.data['requester']['zhipuai-chat-completions']