Merge pull request #685 from RockChinQ/feat/run-multi-adapter

Feat: 支持同时运行多个适配器
This commit is contained in:
Junyan Qin
2024-02-12 13:38:56 +08:00
committed by GitHub
25 changed files with 190 additions and 114 deletions

View File

@@ -34,7 +34,7 @@ class APIGroup(metaclass=abc.ABCMeta):
headers: dict = {},
**kwargs
):
self._runtime_info['account_id'] = "{}".format(self.ap.im_mgr.bot_account_id)
self._runtime_info['account_id'] = "-1"
url = self.prefix + path
data = json.dumps(data)

View File

@@ -4,6 +4,8 @@ import logging
import asyncio
import traceback
import aioconsole
from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
@@ -72,11 +74,21 @@ class Application:
tasks = []
try:
tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.ctrl.run())
]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
async def interrupt(tasks):
await asyncio.sleep(1.5)
while await aioconsole.ainput("使用 exit 退出程序 > ") != 'exit':
pass
for task in tasks:
task.cancel()
await interrupt(tasks)
except asyncio.CancelledError:
pass

View File

@@ -88,7 +88,10 @@ async def make_app() -> app.Application:
},
runtime_info={
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
"msg_source": platform_cfg.data["platform-adapter"],
"msg_source": [
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
],
},
)
ap.ctr_mgr = center_v2_api

View File

@@ -70,7 +70,8 @@ class Controller:
if result.user_notice:
await self.ap.im_mgr.send(
query.message_event,
result.user_notice
result.user_notice,
query.adapter
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
@@ -150,7 +151,7 @@ class Controller:
except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc()
traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")

View File

@@ -12,6 +12,7 @@ from ..provider import entities as llm_entities
from ..provider.requester import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
class LauncherTypes(enum.Enum):
@@ -44,6 +45,9 @@ class Query(pydantic.BaseModel):
message_chain: mirai.MessageChain
"""消息链platform收到的消息链"""
adapter: msadapter.MessageSourceAdapter
"""适配器对象"""
session: typing.Optional[Session] = None
"""会话对象,由前置处理器设置"""
@@ -68,6 +72,9 @@ class Query(pydantic.BaseModel):
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链从resp_messages包装而得"""
class Config:
arbitrary_types_allowed = True
class Conversation(pydantic.BaseModel):
"""对话"""

View File

@@ -5,6 +5,7 @@ import asyncio
import mirai
from . import entities
from ..platform import adapter as msadapter
class QueryPool:
@@ -29,7 +30,8 @@ class QueryPool:
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain
message_chain: mirai.MessageChain,
adapter: msadapter.MessageSourceAdapter
) -> entities.Query:
async with self.condition:
query = entities.Query(
@@ -40,7 +42,8 @@ class QueryPool:
message_event=message_event,
message_chain=message_chain,
resp_messages=[],
resp_message_chain=None
resp_message_chain=None,
adapter=adapter
)
self.queries.append(query)
self.query_id_counter += 1

View File

@@ -52,7 +52,7 @@ class LongTextProcessStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@@ -7,6 +7,7 @@ from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from .. import strategy as strategy_model
from ....core import entities as core_entities
class ForwardMessageDiaplay(MiraiBaseModel):
@@ -37,7 +38,7 @@ class Forward(MessageComponent):
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
@@ -48,7 +49,7 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
node_list = [
ForwardMessageNode(
sender_id=self.ap.im_mgr.bot_account_id,
sender_id=query.adapter.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
)

View File

@@ -12,6 +12,7 @@ from mirai.models import MessageChain, Image as ImageComponent
from mirai.models.message import MessageComponent
from .. import strategy as strategy_model
from ....core import entities as core_entities
class Text2ImageStrategy(strategy_model.LongTextStrategy):
@@ -21,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))

View File

@@ -6,6 +6,7 @@ import mirai
from mirai.models.message import MessageComponent
from ...core import app
from ...core import entities as core_entities
class LongTextStrategy(metaclass=abc.ABCMeta):
@@ -18,5 +19,5 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(self, message: str) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
return []

View File

@@ -31,7 +31,8 @@ class SendResponseBackStage(stage.PipelineStage):
await self.ap.im_mgr.send(
query.message_event,
query.resp_message_chain
query.resp_message_chain,
adapter=query.adapter
)
return entities.StageProcessResult(

View File

@@ -47,7 +47,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
use_rule = use_rule[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule)
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
if res.matching:
query.message_chain = res.replacement

View File

@@ -3,7 +3,7 @@ import abc
import mirai
from ...core import app
from ...core import app, entities as core_entities
from . import entities
@@ -24,7 +24,8 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则
"""

View File

@@ -4,6 +4,7 @@ import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
class AtBotRule(rule_model.GroupRespondRule):
@@ -12,11 +13,12 @@ class AtBotRule(rule_model.GroupRespondRule):
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,

View File

@@ -2,6 +2,7 @@ import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
class PrefixRule(rule_model.GroupRespondRule):
@@ -10,7 +11,8 @@ class PrefixRule(rule_model.GroupRespondRule):
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix']

View File

@@ -4,6 +4,7 @@ import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
class RandomRespRule(rule_model.GroupRespondRule):
@@ -12,7 +13,8 @@ class RandomRespRule(rule_model.GroupRespondRule):
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']

View File

@@ -4,6 +4,7 @@ import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
class RegExpRule(rule_model.GroupRespondRule):
@@ -12,7 +13,8 @@ class RegExpRule(rule_model.GroupRespondRule):
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp']

View File

@@ -71,7 +71,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
):
"""注册事件监听器
@@ -84,7 +84,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
):
"""注销事件监听器

View File

@@ -17,11 +17,8 @@ from ..plugin import events
# 控制QQ消息输入输出的类
class PlatformManager:
adapter: msadapter.MessageSourceAdapter = None
@property
def bot_account_id(self):
return self.adapter.bot_account_id
# adapter: msadapter.MessageSourceAdapter = None
adapters: list[msadapter.MessageSourceAdapter] = []
# modern
ap: app.Application = None
@@ -29,27 +26,13 @@ class PlatformManager:
def __init__(self, ap: app.Application = None):
self.ap = ap
self.adapters = []
async def initialize(self):
from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy
adapter_cls = None
for adapter in msadapter.preregistered_adapters:
if adapter.name == self.ap.platform_cfg.data['platform-adapter']:
adapter_cls = adapter
break
if adapter_cls is None:
raise Exception('未知的平台适配器: ' + self.ap.platform_cfg.data['platform-adapter'])
cfg_key = self.ap.platform_cfg.data['platform-adapter'] + '-config'
self.adapter = adapter_cls(
self.ap.platform_cfg.data[cfg_key],
self.ap
)
async def on_friend_message(event: FriendMessage):
async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived(
@@ -68,15 +51,11 @@ class PlatformManager:
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
message_chain=event.message_chain,
adapter=adapter
)
self.adapter.register_listener(
FriendMessage,
on_friend_message
)
async def on_stranger_message(event: StrangerMessage):
async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived(
@@ -96,16 +75,10 @@ class PlatformManager:
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
# nakuru不区分好友和陌生人故仅为yirimirai注册陌生人事件
if self.ap.platform_cfg.data['platform-adapter'] == 'yiri-mirai':
self.adapter.register_listener(
StrangerMessage,
on_stranger_message
)
async def on_group_message(event: GroupMessage):
async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.GroupMessageReceived(
@@ -124,15 +97,53 @@ class PlatformManager:
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
message_chain=event.message_chain,
adapter=adapter
)
index = 0
self.adapter.register_listener(
GroupMessage,
on_group_message
)
for adap_cfg in self.ap.platform_cfg.data['platform-adapters']:
if adap_cfg['enable']:
self.ap.logger.info(f'初始化平台适配器 {index}: {adap_cfg["adapter"]}')
index += 1
cfg_copy = adap_cfg.copy()
del cfg_copy['enable']
adapter_name = cfg_copy['adapter']
del cfg_copy['adapter']
async def send(self, event, msg, check_quote=True, check_at_sender=True):
found = False
for adapter in msadapter.preregistered_adapters:
if adapter.name == adapter_name:
found = True
adapter_cls = adapter
adapter_inst = adapter_cls(
cfg_copy,
self.ap
)
self.adapters.append(adapter_inst)
if adapter_name == 'yiri-mirai':
adapter_inst.register_listener(
StrangerMessage,
on_stranger_message
)
adapter_inst.register_listener(
FriendMessage,
on_friend_message
)
adapter_inst.register_listener(
GroupMessage,
on_group_message
)
if not found:
raise Exception('platform.json 中启用了未知的平台适配器: ' + adapter_name)
async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
@@ -143,7 +154,7 @@ class PlatformManager:
)
)
await self.adapter.reply_message(
await adapter.reply_message(
event,
msg,
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
@@ -170,7 +181,21 @@ class PlatformManager:
async def run(self):
try:
await self.adapter.run_async()
tasks = []
for adapter in self.adapters:
async def exception_wrapper(adapter):
try:
await adapter.run_async()
except Exception as e:
self.ap.logger.error('平台适配器运行出错: ' + str(e))
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
tasks.append(exception_wrapper(adapter))
for task in tasks:
asyncio.create_task(task)
except Exception as e:
self.ap.logger.error('平台适配器运行出错: ' + str(e))
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")

View File

@@ -240,12 +240,12 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None],
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None],
):
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
try:
return await callback(self.event_converter.target2yiri(event))
return await callback(self.event_converter.target2yiri(event), self)
except:
traceback.print_exc()
@@ -257,7 +257,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None],
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
# 加了之后会导致https://github.com/Lxns-Network/nakuru-project/issues/25
# from __future__ import annotations
import asyncio
import typing
@@ -12,7 +13,6 @@ import nakuru.entities.components as nkc
from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward
from ...core import app
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
@@ -170,11 +170,11 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
listener_list: list[dict]
ap: app.Application
# ap: app.Application
cfg: dict
def __init__(self, cfg: dict, ap: app.Application):
def __init__(self, cfg: dict, ap):
"""初始化nakuru-project的对象"""
cfg['port'] = cfg['ws_port']
del cfg['ws_port']
@@ -257,14 +257,15 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
try:
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
# 包装函数
async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)):
print(1111)
await callback(self.event_converter.target2yiri(source))
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls):
await callback(self.event_converter.target2yiri(source), self)
# 将包装函数和原函数的对应关系存入列表
self.listener_list.append(
@@ -276,7 +277,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
)
# 注册监听器
self.bot.receiver(self.event_converter.yiri2target(event_type).__name__)(listener_wrapper)
self.bot.receiver(source_cls.__name__)(listener_wrapper)
except Exception as e:
traceback.print_exc()
raise e
@@ -284,7 +285,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
@@ -326,7 +327,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
await self.bot._run()
self.ap.logger.info("运行 Nakuru 适配器")
while True:
await asyncio.sleep(100)
await asyncio.sleep(1)
def kill(self) -> bool:
return False

View File

@@ -362,14 +362,14 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
try:
async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]):
self.cached_official_messages[str(message.id)] = message
await callback(OfficialEventConverter.target2yiri(message))
await callback(OfficialEventConverter.target2yiri(message), self)
for event_handler in event_handler_mapping[event_type]:
setattr(self.bot, event_handler, wrapper)
@@ -380,7 +380,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
delattr(self.bot, event_handler_mapping[event_type])

View File

@@ -87,7 +87,7 @@ class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
"""注册事件监听器
@@ -95,12 +95,14 @@ class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
self.bot.on(event_type)(callback)
async def wrapper(event: mirai.Event):
await callback(event, self)
self.bot.on(event_type)(wrapper)
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event], None]
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
"""注销事件监听器

View File

@@ -14,3 +14,4 @@ tiktoken
PyYaml
aiohttp
pydantic
aioconsole

View File

@@ -1,30 +1,38 @@
{
"platform-adapter": "yiri-mirai",
"yiri-mirai-config": {
"adapter": "WebSocketAdapter",
"host": "127.0.0.1",
"port": 8080,
"verifyKey": "yirimirai",
"qq": 123456789
},
"nakuru-config": {
"host": "127.0.0.1",
"ws_port": 8080,
"http_port": 5700,
"token": ""
},
"aiocqhttp-config": {
"host": "127.0.0.1",
"port": 8080
},
"qq-botpy-config": {
"appid": "",
"secret": "",
"intents": [
"public_guild_messages",
"direct_message"
]
},
"platform-adapters": [
{
"adapter": "yiri-mirai",
"enable": false,
"host": "127.0.0.1",
"port": 8080,
"verifyKey": "yirimirai",
"qq": 123456789
},
{
"adapter": "nakuru",
"enable": false,
"host": "127.0.0.1",
"ws_port": 8080,
"http_port": 5700,
"token": ""
},
{
"adapter": "aiocqhttp",
"enable": false,
"host": "127.0.0.1",
"port": 8080
},
{
"adapter": "qq-botpy",
"enable": false,
"appid": "",
"secret": "",
"intents": [
"public_guild_messages",
"direct_message"
]
}
],
"track-function-calls": true,
"quote-origin": false,
"at-sender": false,