feat: 解藕chat的处理器和请求器 (#772)

This commit is contained in:
RockChinQ
2024-05-14 22:20:31 +08:00
parent 972d3c18af
commit 527ad81d38
17 changed files with 205 additions and 137 deletions

View File

@@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator):
content = ""
for msg in prompt.messages:
content += f" {msg.role}: {msg.content}"
content += f" {msg.readable_str()}\n"
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
@@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator):
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]
try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]
try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator):
context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -8,7 +8,10 @@ from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage):
"""访问控制处理阶段"""
"""访问控制处理阶段
仅检查query中群号或个人号是否在访问控制列表中。
"""
async def initialize(self):
pass

View File

@@ -14,7 +14,18 @@ from .filters import cntignore, banwords, baiduexamine
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段"""
"""内容过滤阶段
前置:
检查消息是否符合规则,不符合则拦截。
改写:
message_chain
后置:
检查AI回复消息是否符合规则可能进行改写不符合则拦截。
改写:
query.resp_messages
"""
filter_chain: list[filter_model.ContentFilter]

View File

@@ -8,7 +8,7 @@ from ....config import manager as cfg_mgr
@filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""
"""根据内容过滤"""
async def initialize(self):
pass

View File

@@ -16,6 +16,9 @@ from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage):
"""长消息处理阶段
改写:
- resp_message_chain
"""
strategy_impl: strategy.LongTextStrategy

View File

@@ -9,6 +9,16 @@ from ...plugin import events
@stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage):
"""请求预处理阶段
签出会话、prompt、上文、模型、内容函数。
改写:
- session
- prompt
- messages
- user_message
- use_model
- use_funcs
"""
async def process(
@@ -27,7 +37,7 @@ class PreProcessor(stage.PipelineStage):
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
query.user_message = llm_entities.Message(
query.user_message = llm_entities.Message( # TODO 适配多模态输入
role='user',
content=str(query.message_chain).strip()
)
@@ -37,11 +47,10 @@ class PreProcessor(stage.PipelineStage):
query.use_funcs = conversation.use_funcs
# =========== 触发事件 PromptPreProcessing
session = query.session
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages,
prompt=query.messages,
query=query

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import typing
import time
import traceback
import json
import mirai
@@ -70,17 +71,13 @@ class ChatMessageHandler(handler.MessageHandler):
mirai.Plain(event_ctx.event.alter)
])
query.messages.append(
query.user_message
)
text_length = 0
start_time = time.time()
try:
async for result in query.use_model.requester.request(query):
async for result in self.runner(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
@@ -115,4 +112,65 @@ class ChatMessageHandler(handler.MessageHandler):
model_name=query.use_model.name,
response_seconds=int(time.time() - start_time),
retry_times=-1,
)
)
async def runner(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
这是临时处理方案后续可能改为使用LangChain或者自研的工作流处理器
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)

View File

@@ -11,7 +11,13 @@ from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage):
"""请求实际处理阶段"""
"""请求实际处理阶段
通过命令处理器和聊天处理器处理消息。
改写:
- resp_messages
"""
cmd_handler: handler.MessageHandler

View File

@@ -11,7 +11,10 @@ from ...core import entities as core_entities
@stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy")
class RateLimit(stage.PipelineStage):
"""限速器控制阶段"""
"""限速器控制阶段
不改写query只检查是否需要限速。
"""
algo: algo.ReteLimitAlgo

View File

@@ -14,9 +14,12 @@ from ...config import manager as cfg_mgr
@stage.stage_class("GroupRespondRuleCheckStage")
class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器
仅检查群消息是否符合规则。
"""
rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self):
"""初始化检查器
@@ -31,7 +34,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group':
if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@@ -17,17 +17,17 @@ from .ratelimit import ratelimit
# 请求处理阶段顺序
stage_order = [
"GroupRespondRuleCheckStage",
"BanSessionCheckStage",
"PreContentFilterStage",
"PreProcessor",
"RequireRateLimitOccupancy",
"MessageProcessor",
"ReleaseRateLimitOccupancy",
"PostContentFilterStage",
"ResponseWrapper",
"LongTextProcessStage",
"SendResponseBackStage",
"GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", # 预处理器
"RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器
"LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
]

View File

@@ -14,6 +14,13 @@ from ...plugin import events
@stage.stage_class("ResponseWrapper")
class ResponseWrapper(stage.PipelineStage):
"""回复包装阶段
把回复的 message 包装成人类识读的形式。
改写:
- resp_message_chain
"""
async def initialize(self):
pass
@@ -128,4 +135,4 @@ class ResponseWrapper(stage.PipelineStage):
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
)

View File

@@ -21,6 +21,16 @@ class ToolCall(pydantic.BaseModel):
function: FunctionCall
class Content(pydantic.BaseModel):
type: str
"""内容类型"""
text: typing.Optional[str] = None
image_url: typing.Optional[str] = None
class Message(pydantic.BaseModel):
"""消息"""
@@ -33,9 +43,6 @@ class Message(pydantic.BaseModel):
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
"""内容"""
function_call: typing.Optional[FunctionCall] = None
"""函数调用不再受支持请使用tool_calls"""
tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用"""
@@ -43,9 +50,7 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str:
if self.content is not None:
return str(self.content)
elif self.function_call is not None:
return f'{self.function_call.name}({self.function_call.arguments})'
return str(self.role) + ": " + str(self.content)
elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}'
else:

View File

@@ -6,6 +6,8 @@ import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from . import entities as modelmgr_entities
from ..tools import entities as tools_entities
preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
@@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def initialize(self):
pass
@abc.abstractmethod
async def request(
async def preprocess(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求API
):
"""预处理
在这里处理特定API对Query对象的兼容性问题。
"""
pass
对话前文可以从 query 对象中获取。
可以多次yield消息对象。
@abc.abstractmethod
async def call(
self,
model: modelmgr_entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
"""调用API
Args:
query (core_entities.Query): 本次请求的上下文对象
model (modelmgr_entities.LLMModelInfo): 使用的模型信息
messages (typing.List[llm_entities.Message]): 消息对象列表
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
Yields:
pkg.provider.entities.Message: 返回消息对象
Returns:
llm_entities.Message: 返回消息对象
"""
raise NotImplementedError
pass

View File

@@ -27,20 +27,22 @@ class AnthropicMessages(api.LLMAPIRequester):
proxies=self.ap.proxy_mgr.get_forward_proxies()
)
async def request(
async def call(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
self.client.api_key = query.use_model.token_mgr.get_token()
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
args["model"] = model.name if model.model_name is None else model.model_name
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
] + [m.dict(exclude_none=True) for m in query.messages]
req_messages = [
m.dict(exclude_none=True) for m in messages if m.content.strip() != ""
]
# 删除所有 role=system & content='' 的消息
# 删除所有 role=system & content='' 的消息
req_messages = [
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]
@@ -64,10 +66,9 @@ class AnthropicMessages(api.LLMAPIRequester):
args["messages"] = req_messages
try:
resp = await self.client.messages.create(**args)
yield llm_entities.Message(
return llm_entities.Message(
content=resp.content[0].text,
role=resp.role
)
@@ -79,4 +80,4 @@ class AnthropicMessages(api.LLMAPIRequester):
if 'model: ' in str(e):
raise errors.RequesterError(f'模型无效: {e.message}')
else:
raise errors.RequesterError(f'请求地址无效: {e.message}')
raise errors.RequesterError(f'请求地址无效: {e.message}')

View File

@@ -84,73 +84,19 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
message = await self._make_msg(resp)
return message
async def _request(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求"""
pending_tool_calls = []
async def call(
self,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
] + [m.dict(exclude_none=True) for m in query.messages]
m.dict(exclude_none=True) for m in messages
]
# req_messages.append({"role": "user", "content": str(query.message_chain)})
# 首次请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg.dict(exclude_none=True))
except Exception as e:
# 出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(
err_msg.dict(exclude_none=True)
)
# 处理完所有调用,继续请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
try:
async for msg in self._request(query):
yield msg
return await self._closure(req_messages, model, funcs)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
@@ -163,6 +109,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
except openai.NotFoundError as e:
raise errors.RequesterError(f'请求路径错误: {e.message}')
except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁: {e.message}')
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')