diff --git a/pkg/config/manager.py b/pkg/config/manager.py index e343b0c2..7e52d7b0 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from . import model as file_model from ..utils import context from .impls import pymodule, json as json_file diff --git a/pkg/core/app.py b/pkg/core/app.py index 8c0a0c58..c0ce12fa 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -5,6 +5,9 @@ import asyncio from ..qqbot import manager as qqbot_mgr from ..openai import manager as openai_mgr +from ..openai.session import sessionmgr as llm_session_mgr +from ..openai.requester import modelmgr as llm_model_mgr +from ..openai.sysprompt import sysprompt as llm_prompt_mgr from ..config import manager as config_mgr from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr @@ -18,6 +21,12 @@ class Application: llm_mgr: openai_mgr.OpenAIInteract = None + sess_mgr: llm_session_mgr.SessionManager = None + + model_mgr: llm_model_mgr.ModelManager = None + + prompt_mgr: llm_prompt_mgr.PromptManager = None + cfg_mgr: config_mgr.ConfigManager = None tips_mgr: config_mgr.ConfigManager = None diff --git a/pkg/core/boot.py b/pkg/core/boot.py index 10fc51b3..a74615ec 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -15,7 +15,9 @@ from ..pipeline import stagemgr from ..audit import identifier from ..database import manager as db_mgr from ..openai import manager as llm_mgr -from ..openai import session as llm_session +from ..openai.session import sessionmgr as llm_session_mgr +from ..openai.requester import modelmgr as llm_model_mgr +from ..openai.sysprompt import sysprompt as llm_prompt_mgr from ..openai import dprompt as llm_dprompt from ..qqbot import manager as im_mgr from ..qqbot.cmds import aamgr as im_cmd_aamgr @@ -112,8 +114,18 @@ async def make_app() -> app.Application: llm_mgr_inst = llm_mgr.OpenAIInteract(ap) ap.llm_mgr = llm_mgr_inst - # TODO make it async - llm_session.load_sessions() + + llm_model_mgr_inst = llm_model_mgr.ModelManager(ap) + await llm_model_mgr_inst.initialize() + ap.model_mgr = llm_model_mgr_inst + + llm_session_mgr_inst = llm_session_mgr.SessionManager(ap) + await llm_session_mgr_inst.initialize() + ap.sess_mgr = llm_session_mgr_inst + + llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap) + await llm_prompt_mgr_inst.initialize() + ap.prompt_mgr = llm_prompt_mgr_inst im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap) await im_mgr_inst.initialize() diff --git a/pkg/core/bootutils/config.py b/pkg/core/bootutils/config.py index f1471ae5..0addff08 100644 --- a/pkg/core/bootutils/config.py +++ b/pkg/core/bootutils/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from ...config import manager as config_mgr diff --git a/pkg/core/controller.py b/pkg/core/controller.py index 2470cbbd..ada46f73 100644 --- a/pkg/core/controller.py +++ b/pkg/core/controller.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import typing import traceback from . import app, entities @@ -24,25 +25,115 @@ class Controller: async def consumer(self): """事件处理循环 """ - while True: - selected_query: entities.Query = None + try: + while True: + selected_query: entities.Query = None - # 取请求 - async with self.ap.query_pool: - queries: list[entities.Query] = self.ap.query_pool.queries + # 取请求 + async with self.ap.query_pool: + queries: list[entities.Query] = self.ap.query_pool.queries - if queries: - selected_query = queries.pop(0) # FCFS - else: - await self.ap.query_pool.condition.wait() - continue + for query in queries: + session = await self.ap.sess_mgr.get_session(query) + self.ap.logger.debug(f"Checking query {query} session {session}") - if selected_query: - async def _process_query(selected_query): - async with self.semaphore: - await self.process_query(selected_query) - - asyncio.create_task(_process_query(selected_query)) + if not session.semaphore.locked(): + selected_query = query + await session.semaphore.acquire() + + break + + if selected_query: # 找到了 + queries.remove(selected_query) + else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限 + await self.ap.query_pool.condition.wait() + continue + + if selected_query: + async def _process_query(selected_query): + async with self.semaphore: # 总并发上限 + await self.process_query(selected_query) + + async with self.ap.query_pool: + (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() + # 通知其他协程,有新的请求可以处理了 + self.ap.query_pool.condition.notify_all() + + asyncio.create_task(_process_query(selected_query)) + except Exception as e: + self.ap.logger.error(f"事件处理循环出错: {e}") + traceback.print_exc() + + async def _check_output(self, result: pipeline_entities.StageProcessResult): + """检查输出 + """ + if result.user_notice: + await self.ap.im_mgr.send( + result.user_notice + ) + if result.debug_notice: + self.ap.logger.debug(result.debug_notice) + if result.console_notice: + self.ap.logger.info(result.console_notice) + + async def _execute_from_stage( + self, + stage_index: int, + query: entities.Query, + ): + """从指定阶段开始执行 + + 如何看懂这里为什么这么写? + 去问 GPT-4: + Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None], + 如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result, + 调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器 + Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage: + + A B C D E F G + + 如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是: + + A B C D E F G + + 现在假设C返回的是AsyncGenerator,那么执行顺序是: + + A B C D E F G C D E F G C D E F G ... + Q3: 但是如果不止一个stage会返回生成器呢? + """ + i = stage_index + + while i < len(self.ap.stage_mgr.stage_containers): + stage_container = self.ap.stage_mgr.stage_containers[i] + + result = await stage_container.inst.process(query, stage_container.inst_name) + + + if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") + await self._check_output(result) + + if result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif result.result_type == pipeline_entities.ResultType.CONTINUE: + query = result.new_query + elif isinstance(result, typing.AsyncGenerator): # 生成器 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen") + + async for sub_result in result: + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") + await self._check_output(sub_result) + + if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: + query = sub_result.new_query + await self._execute_from_stage(i + 1, query) + break + + i += 1 async def process_query(self, query: entities.Query): """处理请求 @@ -50,28 +141,7 @@ class Controller: self.ap.logger.debug(f"Processing query {query}") try: - for stage_container in self.ap.stage_mgr.stage_containers: - res = await stage_container.inst.process(query, stage_container.inst_name) - - self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}") - - if res.user_notice: - await self.ap.im_mgr.send( - query.message_event, - res.user_notice - ) - if res.debug_notice: - self.ap.logger.debug(res.debug_notice) - if res.console_notice: - self.ap.logger.info(res.console_notice) - - if res.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") - break - elif res.result_type == pipeline_entities.ResultType.CONTINUE: - query = res.new_query - continue - + await self._execute_from_stage(0, query) except Exception as e: self.ap.logger.error(f"处理请求时出错 {query}: {e}") traceback.print_exc() diff --git a/pkg/openai/entities.py b/pkg/openai/entities.py new file mode 100644 index 00000000..58f48d95 --- /dev/null +++ b/pkg/openai/entities.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import typing +import enum +import pydantic + + +class MessageRole(enum.Enum): + + SYSTEM = 'system' + + USER = 'user' + + ASSISTANT = 'assistant' + + FUNCTION = 'function' + + +class FunctionCall(pydantic.BaseModel): + name: str + + args: dict[str, typing.Any] + + +class Message(pydantic.BaseModel): + + role: MessageRole + + content: typing.Optional[str] = None + + function_call: typing.Optional[FunctionCall] = None diff --git a/pkg/openai/requester/__init__.py b/pkg/openai/requester/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/requester/api.py b/pkg/openai/requester/api.py new file mode 100644 index 00000000..5dd0abf2 --- /dev/null +++ b/pkg/openai/requester/api.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import abc +import typing + +from ...core import app +from ...core import entities as core_entities +from .. import entities as llm_entities +from ..session import entities as session_entities + +class LLMAPIRequester(metaclass=abc.ABCMeta): + """LLM API请求器 + """ + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def request( + self, + query: core_entities.Query, + conversation: session_entities.Conversation, + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """请求 + """ + raise NotImplementedError diff --git a/pkg/openai/requester/apis/__init__.py b/pkg/openai/requester/apis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/requester/apis/chatcmpl.py b/pkg/openai/requester/apis/chatcmpl.py new file mode 100644 index 00000000..5b6d2297 --- /dev/null +++ b/pkg/openai/requester/apis/chatcmpl.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import asyncio +import typing + +import openai + +from .. import api +from ....core import entities as core_entities +from ... import entities as llm_entities +from ...session import entities as session_entities + + +class OpenAIChatCompletion(api.LLMAPIRequester): + + client: openai.Client + + async def initialize(self): + self.client = openai.Client( + base_url=self.ap.cfg_mgr.data['openai_config']['reverse_proxy'], + timeout=self.ap.cfg_mgr.data['process_message_timeout'] + ) + + async def request(self, query: core_entities.Query, conversation: session_entities.Conversation) -> typing.AsyncGenerator[llm_entities.Message, None]: + """请求 + """ + await asyncio.sleep(10) + + yield llm_entities.Message( + role=llm_entities.MessageRole.ASSISTANT, + content="hello" + ) diff --git a/pkg/openai/requester/entities.py b/pkg/openai/requester/entities.py new file mode 100644 index 00000000..adc86677 --- /dev/null +++ b/pkg/openai/requester/entities.py @@ -0,0 +1,23 @@ +import typing + +import pydantic + +from . import api +from . import token + + +class LLMModelInfo(pydantic.BaseModel): + """模型""" + + name: str + + provider: str + + token_mgr: token.TokenManager + + requester: api.LLMAPIRequester + + function_call_supported: typing.Optional[bool] = False + + class Config: + arbitrary_types_allowed = True diff --git a/pkg/openai/requester/modelmgr.py b/pkg/openai/requester/modelmgr.py new file mode 100644 index 00000000..cc606b03 --- /dev/null +++ b/pkg/openai/requester/modelmgr.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from . import entities +from ...core import app + +from .apis import chatcmpl +from . import token + + +class ModelManager: + + ap: app.Application + + model_list: list[entities.LLMModelInfo] + + def __init__(self, ap: app.Application): + self.ap = ap + self.model_list = [] + + async def initialize(self): + openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap) + openai_token_mgr = token.TokenManager(self.ap, self.ap.cfg_mgr.data['openai_config']['api_key'].values()) + + self.model_list.append( + entities.LLMModelInfo( + name="gpt-3.5-turbo", + provider="openai", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + function_call_supported=True + ) + ) + + async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: + """通过名称获取模型 + """ + for model in self.model_list: + if model.name == name: + return model + raise ValueError(f"Model {name} not found") \ No newline at end of file diff --git a/pkg/openai/requester/token.py b/pkg/openai/requester/token.py new file mode 100644 index 00000000..9277c1a6 --- /dev/null +++ b/pkg/openai/requester/token.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import typing + +import pydantic + + +class TokenManager(): + + provider: str + + tokens: list[str] + + using_token_index: typing.Optional[int] = 0 + + def __init__(self, provider: str, tokens: list[str]): + self.provider = provider + self.tokens = tokens + self.using_token_index = 0 + + def get_token(self) -> str: + return self.tokens[self.using_token_index] + + def next_token(self): + self.using_token_index = (self.using_token_index + 1) % len(self.tokens) diff --git a/pkg/openai/session/__init__.py b/pkg/openai/session/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/session/entities.py b/pkg/openai/session/entities.py new file mode 100644 index 00000000..49ddb845 --- /dev/null +++ b/pkg/openai/session/entities.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import datetime +import asyncio +import typing + +import pydantic + +from ..sysprompt import entities as sysprompt_entities +from .. import entities as llm_entities +from ..requester import entities +from ...core import entities as core_entities + + +class Conversation(pydantic.BaseModel): + """对话""" + + prompt: sysprompt_entities.Prompt + + messages: list[llm_entities.Message] + + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + use_model: entities.LLMModelInfo + + +class Session(pydantic.BaseModel): + """会话""" + launcher_type: core_entities.LauncherTypes + + launcher_id: int + + sender_id: typing.Optional[int] = 0 + + use_prompt_name: typing.Optional[str] = 'default' + + using_conversation: typing.Optional[Conversation] = None + + conversations: typing.Optional[list[Conversation]] = [] + + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + semaphore: typing.Optional[asyncio.Semaphore] = None + + class Config: + arbitrary_types_allowed = True diff --git a/pkg/openai/session/sessionmgr.py b/pkg/openai/session/sessionmgr.py new file mode 100644 index 00000000..8aff6e02 --- /dev/null +++ b/pkg/openai/session/sessionmgr.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import asyncio + +from ...core import app, entities as core_entities +from . import entities + + +class SessionManager: + + ap: app.Application + + session_list: list[entities.Session] + + def __init__(self, ap: app.Application): + self.ap = ap + self.session_list = [] + + async def initialize(self): + pass + + async def get_session(self, query: core_entities.Query) -> entities.Session: + """获取会话 + """ + for session in self.session_list: + if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: + return session + + session = entities.Session( + launcher_type=query.launcher_type, + launcher_id=query.launcher_id, + semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000) + ) + self.session_list.append(session) + return session + + async def get_conversation(self, session: entities.Session) -> entities.Conversation: + if not session.conversations: + session.conversations = [] + + if session.using_conversation is None: + conversation = entities.Conversation( + prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), + messages=[], + use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']), + ) + session.conversations.append(conversation) + session.using_conversation = conversation + + return session.using_conversation diff --git a/pkg/openai/sysprompt/__init__.py b/pkg/openai/sysprompt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/sysprompt/entities.py b/pkg/openai/sysprompt/entities.py new file mode 100644 index 00000000..43cd3bf7 --- /dev/null +++ b/pkg/openai/sysprompt/entities.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import typing +import pydantic + +from ...openai import entities + + +class Prompt(pydantic.BaseModel): + """供AI使用的Prompt""" + + name: str + + messages: list[entities.Message] diff --git a/pkg/openai/sysprompt/loader.py b/pkg/openai/sysprompt/loader.py new file mode 100644 index 00000000..ca9e8730 --- /dev/null +++ b/pkg/openai/sysprompt/loader.py @@ -0,0 +1,32 @@ +from __future__ import annotations +import abc + +from ...core import app +from . import entities + + +class PromptLoader(metaclass=abc.ABCMeta): + """Prompt加载器抽象类 + """ + + ap: app.Application + + prompts: list[entities.Prompt] + + def __init__(self, ap: app.Application): + self.ap = ap + self.prompts = [] + + async def initialize(self): + pass + + @abc.abstractmethod + async def load(self): + """加载Prompt + """ + raise NotImplementedError + + def get_prompts(self) -> list[entities.Prompt]: + """获取Prompt列表 + """ + return self.prompts diff --git a/pkg/openai/sysprompt/loaders/__init__.py b/pkg/openai/sysprompt/loaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/sysprompt/loaders/scenario.py b/pkg/openai/sysprompt/loaders/scenario.py new file mode 100644 index 00000000..e0c2bd33 --- /dev/null +++ b/pkg/openai/sysprompt/loaders/scenario.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import json +import os + +from .. import loader +from .. import entities +from ....openai import entities as llm_entities + + +class ScenarioPromptLoader(loader.PromptLoader): + """加载scenario目录下的json""" + + async def load(self): + """加载Prompt + """ + for file in os.listdir("scenarios"): + with open("scenarios/{}".format(file), "r", encoding="utf-8") as f: + file_str = f.read() + file_name = file.split(".")[0] + file_json = json.loads(file_str) + messages = [] + for msg in file_json["prompt"]: + role = llm_entities.MessageRole.SYSTEM + if "role" in msg: + if msg["role"] == "user": + role = llm_entities.MessageRole.USER + elif msg["role"] == "system": + role = llm_entities.MessageRole.SYSTEM + elif msg["role"] == "function": + role = llm_entities.MessageRole.FUNCTION + messages.append( + llm_entities.Message( + role=role, + content=msg['content'], + ) + ) + prompt = entities.Prompt( + name=file_name, + messages=messages + ) + self.prompts.append(prompt) + \ No newline at end of file diff --git a/pkg/openai/sysprompt/loaders/single.py b/pkg/openai/sysprompt/loaders/single.py new file mode 100644 index 00000000..ad37d878 --- /dev/null +++ b/pkg/openai/sysprompt/loaders/single.py @@ -0,0 +1,42 @@ +from __future__ import annotations +import os + +from .. import loader +from .. import entities +from ....openai import entities as llm_entities + + +class SingleSystemPromptLoader(loader.PromptLoader): + """配置文件中的单条system prompt的prompt加载器 + """ + + async def load(self): + """加载Prompt + """ + + for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items(): + prompt = entities.Prompt( + name=name, + messages=[ + llm_entities.Message( + role=llm_entities.MessageRole.SYSTEM, + content=cnt + ) + ] + ) + self.prompts.append(prompt) + + for file in os.listdir("prompts"): + with open("prompts/{}".format(file), "r", encoding="utf-8") as f: + file_str = f.read() + file_name = file.split(".")[0] + prompt = entities.Prompt( + name=file_name, + messages=[ + llm_entities.Message( + role=llm_entities.MessageRole.SYSTEM, + content=file_str + ) + ] + ) + self.prompts.append(prompt) diff --git a/pkg/openai/sysprompt/sysprompt.py b/pkg/openai/sysprompt/sysprompt.py new file mode 100644 index 00000000..050f6639 --- /dev/null +++ b/pkg/openai/sysprompt/sysprompt.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from ...core import app +from . import loader +from .loaders import single, scenario + + +class PromptManager: + + ap: app.Application + + loader_inst: loader.PromptLoader + + default_prompt: str = 'default' + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + + loader_map = { + "normal": single.SingleSystemPromptLoader, + "full_scenario": scenario.ScenarioPromptLoader + } + + loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']] + + self.loader_inst: loader.PromptLoader = loader_cls(self.ap) + + await self.loader_inst.initialize() + await self.loader_inst.load() + + def get_all_prompts(self) -> list[loader.entities.Prompt]: + """获取所有Prompt + """ + return self.loader_inst.get_prompts() + + async def get_prompt(self, name: str) -> loader.entities.Prompt: + """获取Prompt + """ + for prompt in self.get_all_prompts(): + if prompt.name == name: + return prompt diff --git a/pkg/pipeline/process/__init__.py b/pkg/pipeline/process/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/process/handler.py b/pkg/pipeline/process/handler.py new file mode 100644 index 00000000..6d19e039 --- /dev/null +++ b/pkg/pipeline/process/handler.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import abc + +from ...core import app +from ...core import entities as core_entities +from .. import entities + + +class MessageHandler(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def handle( + self, + query: core_entities.Query, + ) -> entities.StageProcessResult: + raise NotImplementedError diff --git a/pkg/pipeline/process/handlers/__init__.py b/pkg/pipeline/process/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py new file mode 100644 index 00000000..ebe958bf --- /dev/null +++ b/pkg/pipeline/process/handlers/chat.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import typing + +import mirai + +from .. import handler +from ... import entities +from ....core import entities as core_entities + + +class ChatMessageHandler(handler.MessageHandler): + + async def handle( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: + """处理 + """ + # 取session + # 取conversation + # 调API + # 生成器 + session = await self.ap.sess_mgr.get_session(query) + + conversation = await self.ap.sess_mgr.get_conversation(session) + + async for result in conversation.use_model.requester.request(query, conversation): + query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + + + diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py new file mode 100644 index 00000000..c5fecb67 --- /dev/null +++ b/pkg/pipeline/process/handlers/command.py @@ -0,0 +1,35 @@ +from __future__ import annotations +import typing + +import mirai + +from .. import handler +from ... import entities +from ....core import entities as core_entities + + +class CommandHandler(handler.MessageHandler): + + async def handle( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: + """处理 + """ + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain('CommandHandler') + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain('The Second Message') + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py new file mode 100644 index 00000000..29051431 --- /dev/null +++ b/pkg/pipeline/process/process.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from ...core import app, entities as core_entities +from . import handler +from .handlers import chat, command +from .. import entities +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("MessageProcessor") +class Processor(stage.PipelineStage): + + cmd_handler: handler.MessageHandler + + chat_handler: handler.MessageHandler + + async def initialize(self): + self.cmd_handler = command.CommandHandler(self.ap) + self.chat_handler = chat.ChatMessageHandler(self.ap) + + await self.cmd_handler.initialize() + await self.chat_handler.initialize() + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> entities.StageProcessResult: + """处理 + """ + message_text = str(query.message_chain).strip() + + if message_text.startswith('!') or message_text.startswith('!'): + return self.cmd_handler.handle(query) + else: + return self.chat_handler.handle(query) diff --git a/pkg/pipeline/respback/__init__.py b/pkg/pipeline/respback/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py new file mode 100644 index 00000000..4dbddaa5 --- /dev/null +++ b/pkg/pipeline/respback/respback.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import mirai + +from ...core import app + +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr + + +@stage.stage_class("SendResponseBackStage") +class SendResponseBackStage(stage.PipelineStage): + """发送响应消息 + """ + + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + """处理 + """ + + await self.ap.im_mgr.send( + query.message_event, + query.resp_message_chain + ) + + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py index 84a0339d..56c092b5 100644 --- a/pkg/pipeline/stage.py +++ b/pkg/pipeline/stage.py @@ -1,6 +1,7 @@ from __future__ import annotations import abc +import typing from ..core import app, entities as core_entities from . import entities @@ -37,7 +38,10 @@ class PipelineStage(metaclass=abc.ABCMeta): self, query: core_entities.Query, stage_inst_name: str, - ) -> entities.StageProcessResult: + ) -> typing.Union[ + entities.StageProcessResult, + typing.AsyncGenerator[entities.StageProcessResult, None], + ]: """处理 """ raise NotImplementedError diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py index f5407a2e..1ff36329 100644 --- a/pkg/pipeline/stagemgr.py +++ b/pkg/pipeline/stagemgr.py @@ -7,7 +7,20 @@ from . import stage from .resprule import resprule from .bansess import bansess from .cntfilter import cntfilter +from .process import process from .longtext import longtext +from .respback import respback + + +stage_order = [ + "GroupRespondRuleCheckStage", + "BanSessionCheckStage", + "PreContentFilterStage", + "MessageProcessor", + "PostContentFilterStage", + "LongTextProcessStage", + "SendResponseBackStage", +] class StageInstContainer(): @@ -45,3 +58,6 @@ class StageManager: for stage_containers in self.stage_containers: await stage_containers.inst.initialize() + + # 按照 stage_order 排序 + self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name)) diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index b16450e8..7794663a 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -18,10 +18,6 @@ from ..plugin import host as plugin_host from ..plugin import models as plugin_models import tips as tips_custom from ..qqbot import adapter as msadapter -from .resprule import resprule -from .bansess import bansess -from .cntfilter import cntfilter -from .longtext import longtext from .ratelim import ratelim from ..core import app, entities as core_entities @@ -41,30 +37,18 @@ class QQBotManager: # modern ap: app.Application = None - bansess_mgr: bansess.SessionBanManager = None - cntfilter_mgr: cntfilter.ContentFilterManager = None - longtext_pcs: longtext.LongTextProcessor = None - resprule_chkr: resprule.GroupRespondRuleChecker = None ratelimiter: ratelim.RateLimiter = None def __init__(self, first_time_init=True, ap: app.Application = None): config = context.get_config_manager().data self.ap = ap - self.bansess_mgr = bansess.SessionBanManager(ap) - self.cntfilter_mgr = cntfilter.ContentFilterManager(ap) - self.longtext_pcs = longtext.LongTextProcessor(ap) - self.resprule_chkr = resprule.GroupRespondRuleChecker(ap) self.ratelimiter = ratelim.RateLimiter(ap) self.timeout = config['process_message_timeout'] self.retry = config['retry_times'] async def initialize(self): - await self.bansess_mgr.initialize() - await self.cntfilter_mgr.initialize() - await self.longtext_pcs.initialize() - await self.resprule_chkr.initialize() await self.ratelimiter.initialize() config = context.get_config_manager().data diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 65de8d52..a8359be5 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -15,7 +15,7 @@ from ..plugin import host as plugin_host from ..plugin import models as plugin_models import tips as tips_custom from ..core import app -from .cntfilter import entities +# from .cntfilter import entities processing = []