From 527ad81d38aa11d869cb37687ee832c865fd3bb3 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 14 May 2024 22:20:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=A7=A3=E8=97=95chat=E7=9A=84?= =?UTF-8?q?=E5=A4=84=E7=90=86=E5=99=A8=E5=92=8C=E8=AF=B7=E6=B1=82=E5=99=A8?= =?UTF-8?q?=20(#772)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operators/default.py | 32 ++++----- pkg/command/operators/last.py | 2 +- pkg/pipeline/bansess/bansess.py | 5 +- pkg/pipeline/cntfilter/cntfilter.py | 13 +++- pkg/pipeline/cntfilter/filters/banwords.py | 2 +- pkg/pipeline/longtext/longtext.py | 3 + pkg/pipeline/preproc/preproc.py | 15 +++- pkg/pipeline/process/handlers/chat.py | 70 +++++++++++++++++-- pkg/pipeline/process/process.py | 8 ++- pkg/pipeline/ratelimit/ratelimit.py | 5 +- pkg/pipeline/resprule/resprule.py | 5 +- pkg/pipeline/stagemgr.py | 22 +++--- pkg/pipeline/wrapper/wrapper.py | 9 ++- pkg/provider/entities.py | 17 +++-- pkg/provider/modelmgr/api.py | 33 ++++++--- pkg/provider/modelmgr/apis/anthropicmsgs.py | 25 +++---- pkg/provider/modelmgr/apis/chatcmpl.py | 76 +++------------------ 17 files changed, 205 insertions(+), 137 deletions(-) diff --git a/pkg/command/operators/default.py b/pkg/command/operators/default.py index ca7e404d..ee46c7d0 100644 --- a/pkg/command/operators/default.py +++ b/pkg/command/operators/default.py @@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator): content = "" for msg in prompt.messages: - content += f" {msg.role}: {msg.content}" + content += f" {msg.readable_str()}\n" reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" @@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator): context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) - else: - prompt_name = context.crt_params[0] - - try: - prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) - if prompt is None: - yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) - else: - context.session.use_prompt_name = prompt.name - yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") - except Exception as e: - traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) + else: + prompt_name = context.crt_params[0] + + try: + prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) + if prompt is None: + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) + else: + context.session.use_prompt_name = prompt.name + yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py index 8e3a5231..e7a14c83 100644 --- a/pkg/command/operators/last.py +++ b/pkg/command/operators/last.py @@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator): context.session.using_conversation = context.session.conversations[index-1] time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") - yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}") return else: yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 95a7cffd..9c041385 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -8,7 +8,10 @@ from ...config import manager as cfg_mgr @stage.stage_class('BanSessionCheckStage') class BanSessionCheckStage(stage.PipelineStage): - """访问控制处理阶段""" + """访问控制处理阶段 + + 仅检查query中群号或个人号是否在访问控制列表中。 + """ async def initialize(self): pass diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 21b6c250..2c6a5ab9 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -14,7 +14,18 @@ from .filters import cntignore, banwords, baiduexamine @stage.stage_class('PostContentFilterStage') @stage.stage_class('PreContentFilterStage') class ContentFilterStage(stage.PipelineStage): - """内容过滤阶段""" + """内容过滤阶段 + + 前置: + 检查消息是否符合规则,不符合则拦截。 + 改写: + message_chain + + 后置: + 检查AI回复消息是否符合规则,可能进行改写,不符合则拦截。 + 改写: + query.resp_messages + """ filter_chain: list[filter_model.ContentFilter] diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 5cd7dcfa..1430c2ed 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -8,7 +8,7 @@ from ....config import manager as cfg_mgr @filter_model.filter_class("ban-word-filter") class BanWordFilter(filter_model.ContentFilter): - """根据内容禁言""" + """根据内容过滤""" async def initialize(self): pass diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 28c28143..ec0e66e4 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -16,6 +16,9 @@ from ...config import manager as cfg_mgr @stage.stage_class("LongTextProcessStage") class LongTextProcessStage(stage.PipelineStage): """长消息处理阶段 + + 改写: + - resp_message_chain """ strategy_impl: strategy.LongTextStrategy diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index cedc030f..164f78c8 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -9,6 +9,16 @@ from ...plugin import events @stage.stage_class("PreProcessor") class PreProcessor(stage.PipelineStage): """请求预处理阶段 + + 签出会话、prompt、上文、模型、内容函数。 + + 改写: + - session + - prompt + - messages + - user_message + - use_model + - use_funcs """ async def process( @@ -27,7 +37,7 @@ class PreProcessor(stage.PipelineStage): query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() - query.user_message = llm_entities.Message( + query.user_message = llm_entities.Message( # TODO 适配多模态输入 role='user', content=str(query.message_chain).strip() ) @@ -37,11 +47,10 @@ class PreProcessor(stage.PipelineStage): query.use_funcs = conversation.use_funcs # =========== 触发事件 PromptPreProcessing - session = query.session event_ctx = await self.ap.plugin_mgr.emit_event( event=events.PromptPreProcessing( - session_name=f'{session.launcher_type.value}_{session.launcher_id}', + session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}', default_prompt=query.prompt.messages, prompt=query.messages, query=query diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index f38ee341..2f0616f3 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing import time import traceback +import json import mirai @@ -70,17 +71,13 @@ class ChatMessageHandler(handler.MessageHandler): mirai.Plain(event_ctx.event.alter) ]) - query.messages.append( - query.user_message - ) - text_length = 0 start_time = time.time() try: - async for result in query.use_model.requester.request(query): + async for result in self.runner(query): query.resp_messages.append(result) self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') @@ -115,4 +112,65 @@ class ChatMessageHandler(handler.MessageHandler): model_name=query.use_model.name, response_seconds=int(time.time() - start_time), retry_times=-1, - ) \ No newline at end of file + ) + + async def runner( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """执行一个请求处理过程中的LLM接口请求、函数调用的循环 + + 这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器 + """ + await query.use_model.requester.preprocess(query) + + pending_tool_calls = [] + + req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] + + # 首次请求 + msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg) + + # 持续请求,只要还有待处理的工具调用就继续处理调用 + while pending_tool_calls: + for tool_call in pending_tool_calls: + try: + func = tool_call.function + + parameters = json.loads(func.arguments) + + func_ret = await self.ap.tool_mgr.execute_func_call( + query, func.name, parameters + ) + + msg = llm_entities.Message( + role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id + ) + + yield msg + + req_messages.append(msg) + except Exception as e: + # 工具调用出错,添加一个报错信息到 req_messages + err_msg = llm_entities.Message( + role="tool", content=f"err: {e}", tool_call_id=tool_call.id + ) + + yield err_msg + + req_messages.append(err_msg) + + # 处理完所有调用,再次请求 + msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg) diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index ddf8809e..e58d15ee 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -11,7 +11,13 @@ from ...config import manager as cfg_mgr @stage.stage_class("MessageProcessor") class Processor(stage.PipelineStage): - """请求实际处理阶段""" + """请求实际处理阶段 + + 通过命令处理器和聊天处理器处理消息。 + + 改写: + - resp_messages + """ cmd_handler: handler.MessageHandler diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index 2622247a..cd39b85c 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -11,7 +11,10 @@ from ...core import entities as core_entities @stage.stage_class("RequireRateLimitOccupancy") @stage.stage_class("ReleaseRateLimitOccupancy") class RateLimit(stage.PipelineStage): - """限速器控制阶段""" + """限速器控制阶段 + + 不改写query,只检查是否需要限速。 + """ algo: algo.ReteLimitAlgo diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index d795d056..fce0c4ec 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -14,9 +14,12 @@ from ...config import manager as cfg_mgr @stage.stage_class("GroupRespondRuleCheckStage") class GroupRespondRuleCheckStage(stage.PipelineStage): """群组响应规则检查器 + + 仅检查群消息是否符合规则。 """ rule_matchers: list[rule.GroupRespondRule] + """检查器实例""" async def initialize(self): """初始化检查器 @@ -31,7 +34,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - if query.launcher_type.value != 'group': + if query.launcher_type.value != 'group': # 只处理群消息 return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py index 23c7897d..46957aad 100644 --- a/pkg/pipeline/stagemgr.py +++ b/pkg/pipeline/stagemgr.py @@ -17,17 +17,17 @@ from .ratelimit import ratelimit # 请求处理阶段顺序 stage_order = [ - "GroupRespondRuleCheckStage", - "BanSessionCheckStage", - "PreContentFilterStage", - "PreProcessor", - "RequireRateLimitOccupancy", - "MessageProcessor", - "ReleaseRateLimitOccupancy", - "PostContentFilterStage", - "ResponseWrapper", - "LongTextProcessStage", - "SendResponseBackStage", + "GroupRespondRuleCheckStage", # 群响应规则检查 + "BanSessionCheckStage", # 封禁会话检查 + "PreContentFilterStage", # 内容过滤前置阶段 + "PreProcessor", # 预处理器 + "RequireRateLimitOccupancy", # 请求速率限制占用 + "MessageProcessor", # 处理器 + "ReleaseRateLimitOccupancy", # 释放速率限制占用 + "PostContentFilterStage", # 内容过滤后置阶段 + "ResponseWrapper", # 响应包装器 + "LongTextProcessStage", # 长文本处理 + "SendResponseBackStage", # 发送响应 ] diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index a500d7c2..78705e6b 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -14,6 +14,13 @@ from ...plugin import events @stage.stage_class("ResponseWrapper") class ResponseWrapper(stage.PipelineStage): + """回复包装阶段 + + 把回复的 message 包装成人类识读的形式。 + + 改写: + - resp_message_chain + """ async def initialize(self): pass @@ -128,4 +135,4 @@ class ResponseWrapper(stage.PipelineStage): yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query - ) \ No newline at end of file + ) diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index a30d4e33..3281a93b 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -21,6 +21,16 @@ class ToolCall(pydantic.BaseModel): function: FunctionCall +class Content(pydantic.BaseModel): + + type: str + """内容类型""" + + text: typing.Optional[str] = None + + image_url: typing.Optional[str] = None + + class Message(pydantic.BaseModel): """消息""" @@ -33,9 +43,6 @@ class Message(pydantic.BaseModel): content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None """内容""" - function_call: typing.Optional[FunctionCall] = None - """函数调用,不再受支持,请使用tool_calls""" - tool_calls: typing.Optional[list[ToolCall]] = None """工具调用""" @@ -43,9 +50,7 @@ class Message(pydantic.BaseModel): def readable_str(self) -> str: if self.content is not None: - return str(self.content) - elif self.function_call is not None: - return f'{self.function_call.name}({self.function_call.arguments})' + return str(self.role) + ": " + str(self.content) elif self.tool_calls is not None: return f'调用工具: {self.tool_calls[0].id}' else: diff --git a/pkg/provider/modelmgr/api.py b/pkg/provider/modelmgr/api.py index 63021bed..930cf9e7 100644 --- a/pkg/provider/modelmgr/api.py +++ b/pkg/provider/modelmgr/api.py @@ -6,6 +6,8 @@ import typing from ...core import app from ...core import entities as core_entities from .. import entities as llm_entities +from . import entities as modelmgr_entities +from ..tools import entities as tools_entities preregistered_requesters: list[typing.Type[LLMAPIRequester]] = [] @@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): async def initialize(self): pass - @abc.abstractmethod - async def request( + async def preprocess( self, query: core_entities.Query, - ) -> typing.AsyncGenerator[llm_entities.Message, None]: - """请求API + ): + """预处理 + + 在这里处理特定API对Query对象的兼容性问题。 + """ + pass - 对话前文可以从 query 对象中获取。 - 可以多次yield消息对象。 + @abc.abstractmethod + async def call( + self, + model: modelmgr_entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + """调用API Args: - query (core_entities.Query): 本次请求的上下文对象 + model (modelmgr_entities.LLMModelInfo): 使用的模型信息 + messages (typing.List[llm_entities.Message]): 消息对象列表 + funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. - Yields: - pkg.provider.entities.Message: 返回消息对象 + Returns: + llm_entities.Message: 返回消息对象 """ - raise NotImplementedError + pass diff --git a/pkg/provider/modelmgr/apis/anthropicmsgs.py b/pkg/provider/modelmgr/apis/anthropicmsgs.py index 42bd3856..923e1ceb 100644 --- a/pkg/provider/modelmgr/apis/anthropicmsgs.py +++ b/pkg/provider/modelmgr/apis/anthropicmsgs.py @@ -27,20 +27,22 @@ class AnthropicMessages(api.LLMAPIRequester): proxies=self.ap.proxy_mgr.get_forward_proxies() ) - async def request( + async def call( self, - query: core_entities.Query, - ) -> typing.AsyncGenerator[llm_entities.Message, None]: - self.client.api_key = query.use_model.token_mgr.get_token() + model: entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + self.client.api_key = model.token_mgr.get_token() args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy() - args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name + args["model"] = model.name if model.model_name is None else model.model_name - req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 - m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != "" - ] + [m.dict(exclude_none=True) for m in query.messages] + req_messages = [ + m.dict(exclude_none=True) for m in messages if m.content.strip() != "" + ] - # 删除所有 role=system & content='' 的消息 + # 删除所有 role=system & content='' 的消息 req_messages = [ m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "") ] @@ -64,10 +66,9 @@ class AnthropicMessages(api.LLMAPIRequester): args["messages"] = req_messages try: - resp = await self.client.messages.create(**args) - yield llm_entities.Message( + return llm_entities.Message( content=resp.content[0].text, role=resp.role ) @@ -79,4 +80,4 @@ class AnthropicMessages(api.LLMAPIRequester): if 'model: ' in str(e): raise errors.RequesterError(f'模型无效: {e.message}') else: - raise errors.RequesterError(f'请求地址无效: {e.message}') \ No newline at end of file + raise errors.RequesterError(f'请求地址无效: {e.message}') diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py index e3901de5..7984dd83 100644 --- a/pkg/provider/modelmgr/apis/chatcmpl.py +++ b/pkg/provider/modelmgr/apis/chatcmpl.py @@ -84,73 +84,19 @@ class OpenAIChatCompletions(api.LLMAPIRequester): message = await self._make_msg(resp) return message - - async def _request( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: - """请求""" - - pending_tool_calls = [] - + + async def call( + self, + model: entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 - m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != "" - ] + [m.dict(exclude_none=True) for m in query.messages] + m.dict(exclude_none=True) for m in messages + ] - # req_messages.append({"role": "user", "content": str(query.message_chain)}) - - # 首次请求 - msg = await self._closure(req_messages, query.use_model, query.use_funcs) - - yield msg - - pending_tool_calls = msg.tool_calls - - req_messages.append(msg.dict(exclude_none=True)) - - # 持续请求,只要还有待处理的工具调用就继续处理调用 - while pending_tool_calls: - for tool_call in pending_tool_calls: - try: - func = tool_call.function - - parameters = json.loads(func.arguments) - - func_ret = await self.ap.tool_mgr.execute_func_call( - query, func.name, parameters - ) - - msg = llm_entities.Message( - role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id - ) - - yield msg - - req_messages.append(msg.dict(exclude_none=True)) - except Exception as e: - # 出错,添加一个报错信息到 req_messages - err_msg = llm_entities.Message( - role="tool", content=f"err: {e}", tool_call_id=tool_call.id - ) - - yield err_msg - - req_messages.append( - err_msg.dict(exclude_none=True) - ) - - # 处理完所有调用,继续请求 - msg = await self._closure(req_messages, query.use_model, query.use_funcs) - - yield msg - - pending_tool_calls = msg.tool_calls - - req_messages.append(msg.dict(exclude_none=True)) - - async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]: try: - async for msg in self._request(query): - yield msg + return await self._closure(req_messages, model, funcs) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: @@ -163,6 +109,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester): except openai.NotFoundError as e: raise errors.RequesterError(f'请求路径错误: {e.message}') except openai.RateLimitError as e: - raise errors.RequesterError(f'请求过于频繁: {e.message}') + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}')