feat: 解藕chat的处理器和请求器 (#772)

This commit is contained in:
RockChinQ
2024-05-14 22:20:31 +08:00
parent 972d3c18af
commit 527ad81d38
17 changed files with 205 additions and 137 deletions
+11 -6
View File
@@ -21,6 +21,16 @@ class ToolCall(pydantic.BaseModel):
function: FunctionCall
class Content(pydantic.BaseModel):
type: str
"""内容类型"""
text: typing.Optional[str] = None
image_url: typing.Optional[str] = None
class Message(pydantic.BaseModel):
"""消息"""
@@ -33,9 +43,6 @@ class Message(pydantic.BaseModel):
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
"""内容"""
function_call: typing.Optional[FunctionCall] = None
"""函数调用,不再受支持,请使用tool_calls"""
tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用"""
@@ -43,9 +50,7 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str:
if self.content is not None:
return str(self.content)
elif self.function_call is not None:
return f'{self.function_call.name}({self.function_call.arguments})'
return str(self.role) + ": " + str(self.content)
elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}'
else:
+23 -10
View File
@@ -6,6 +6,8 @@ import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from . import entities as modelmgr_entities
from ..tools import entities as tools_entities
preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
@@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def initialize(self):
pass
@abc.abstractmethod
async def request(
async def preprocess(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求API
):
"""预处理
在这里处理特定API对Query对象的兼容性问题。
"""
pass
对话前文可以从 query 对象中获取。
可以多次yield消息对象。
@abc.abstractmethod
async def call(
self,
model: modelmgr_entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
"""调用API
Args:
query (core_entities.Query): 本次请求的上下文对象
model (modelmgr_entities.LLMModelInfo): 使用的模型信息
messages (typing.List[llm_entities.Message]): 消息对象列表
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
Yields:
pkg.provider.entities.Message: 返回消息对象
Returns:
llm_entities.Message: 返回消息对象
"""
raise NotImplementedError
pass
+13 -12
View File
@@ -27,20 +27,22 @@ class AnthropicMessages(api.LLMAPIRequester):
proxies=self.ap.proxy_mgr.get_forward_proxies()
)
async def request(
async def call(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
self.client.api_key = query.use_model.token_mgr.get_token()
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
args["model"] = model.name if model.model_name is None else model.model_name
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
] + [m.dict(exclude_none=True) for m in query.messages]
req_messages = [
m.dict(exclude_none=True) for m in messages if m.content.strip() != ""
]
# 删除所有 role=system & content='' 的消息
# 删除所有 role=system & content='' 的消息
req_messages = [
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]
@@ -64,10 +66,9 @@ class AnthropicMessages(api.LLMAPIRequester):
args["messages"] = req_messages
try:
resp = await self.client.messages.create(**args)
yield llm_entities.Message(
return llm_entities.Message(
content=resp.content[0].text,
role=resp.role
)
@@ -79,4 +80,4 @@ class AnthropicMessages(api.LLMAPIRequester):
if 'model: ' in str(e):
raise errors.RequesterError(f'模型无效: {e.message}')
else:
raise errors.RequesterError(f'请求地址无效: {e.message}')
raise errors.RequesterError(f'请求地址无效: {e.message}')
+11 -65
View File
@@ -84,73 +84,19 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
message = await self._make_msg(resp)
return message
async def _request(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求"""
pending_tool_calls = []
async def call(
self,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
] + [m.dict(exclude_none=True) for m in query.messages]
m.dict(exclude_none=True) for m in messages
]
# req_messages.append({"role": "user", "content": str(query.message_chain)})
# 首次请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg.dict(exclude_none=True))
except Exception as e:
# 出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(
err_msg.dict(exclude_none=True)
)
# 处理完所有调用,继续请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
try:
async for msg in self._request(query):
yield msg
return await self._closure(req_messages, model, funcs)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
@@ -163,6 +109,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
except openai.NotFoundError as e:
raise errors.RequesterError(f'请求路径错误: {e.message}')
except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁: {e.message}')
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')