mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-03 04:24:36 +00:00
Merge pull request #685 from RockChinQ/feat/run-multi-adapter
Feat: 支持同时运行多个适配器
This commit is contained in:
@@ -34,7 +34,7 @@ class APIGroup(metaclass=abc.ABCMeta):
|
||||
headers: dict = {},
|
||||
**kwargs
|
||||
):
|
||||
self._runtime_info['account_id'] = "{}".format(self.ap.im_mgr.bot_account_id)
|
||||
self._runtime_info['account_id'] = "-1"
|
||||
|
||||
url = self.prefix + path
|
||||
data = json.dumps(data)
|
||||
|
||||
@@ -4,6 +4,8 @@ import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import aioconsole
|
||||
|
||||
from ..platform import manager as im_mgr
|
||||
from ..provider.session import sessionmgr as llm_session_mgr
|
||||
from ..provider.requester import modelmgr as llm_model_mgr
|
||||
@@ -72,11 +74,21 @@ class Application:
|
||||
tasks = []
|
||||
|
||||
try:
|
||||
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(self.im_mgr.run()),
|
||||
asyncio.create_task(self.ctrl.run())
|
||||
]
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
async def interrupt(tasks):
|
||||
await asyncio.sleep(1.5)
|
||||
while await aioconsole.ainput("使用 exit 退出程序 > ") != 'exit':
|
||||
pass
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
await interrupt(tasks)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@@ -88,7 +88,10 @@ async def make_app() -> app.Application:
|
||||
},
|
||||
runtime_info={
|
||||
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
|
||||
"msg_source": platform_cfg.data["platform-adapter"],
|
||||
"msg_source": [
|
||||
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
|
||||
for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
|
||||
],
|
||||
},
|
||||
)
|
||||
ap.ctr_mgr = center_v2_api
|
||||
|
||||
@@ -70,7 +70,8 @@ class Controller:
|
||||
if result.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice
|
||||
result.user_notice,
|
||||
query.adapter
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
@@ -150,7 +151,7 @@ class Controller:
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}")
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
# traceback.print_exc()
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from ..provider import entities as llm_entities
|
||||
from ..provider.requester import entities
|
||||
from ..provider.sysprompt import entities as sysprompt_entities
|
||||
from ..provider.tools import entities as tools_entities
|
||||
from ..platform import adapter as msadapter
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
@@ -44,6 +45,9 @@ class Query(pydantic.BaseModel):
|
||||
message_chain: mirai.MessageChain
|
||||
"""消息链,platform收到的消息链"""
|
||||
|
||||
adapter: msadapter.MessageSourceAdapter
|
||||
"""适配器对象"""
|
||||
|
||||
session: typing.Optional[Session] = None
|
||||
"""会话对象,由前置处理器设置"""
|
||||
|
||||
@@ -68,6 +72,9 @@ class Query(pydantic.BaseModel):
|
||||
resp_message_chain: typing.Optional[mirai.MessageChain] = None
|
||||
"""回复消息链,从resp_messages包装而得"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话"""
|
||||
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
import mirai
|
||||
|
||||
from . import entities
|
||||
from ..platform import adapter as msadapter
|
||||
|
||||
|
||||
class QueryPool:
|
||||
@@ -29,7 +30,8 @@ class QueryPool:
|
||||
launcher_id: int,
|
||||
sender_id: int,
|
||||
message_event: mirai.MessageEvent,
|
||||
message_chain: mirai.MessageChain
|
||||
message_chain: mirai.MessageChain,
|
||||
adapter: msadapter.MessageSourceAdapter
|
||||
) -> entities.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
@@ -40,7 +42,8 @@ class QueryPool:
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
resp_messages=[],
|
||||
resp_message_chain=None
|
||||
resp_message_chain=None,
|
||||
adapter=adapter
|
||||
)
|
||||
self.queries.append(query)
|
||||
self.query_id_counter += 1
|
||||
|
||||
@@ -52,7 +52,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
|
||||
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
|
||||
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
|
||||
@@ -7,6 +7,7 @@ from mirai.models.message import MessageComponent, ForwardMessageNode
|
||||
from mirai.models.base import MiraiBaseModel
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class ForwardMessageDiaplay(MiraiBaseModel):
|
||||
@@ -37,7 +38,7 @@ class Forward(MessageComponent):
|
||||
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
async def process(self, message: str) -> list[MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
display = ForwardMessageDiaplay(
|
||||
title="群聊的聊天记录",
|
||||
brief="[聊天记录]",
|
||||
@@ -48,7 +49,7 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
node_list = [
|
||||
ForwardMessageNode(
|
||||
sender_id=self.ap.im_mgr.bot_account_id,
|
||||
sender_id=query.adapter.bot_account_id,
|
||||
sender_name='QQ用户',
|
||||
message_chain=MessageChain([message])
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from mirai.models import MessageChain, Image as ImageComponent
|
||||
from mirai.models.message import MessageComponent
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
@@ -21,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
async def initialize(self):
|
||||
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
|
||||
|
||||
async def process(self, message: str) -> list[MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time()))
|
||||
|
||||
@@ -6,6 +6,7 @@ import mirai
|
||||
from mirai.models.message import MessageComponent
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
|
||||
|
||||
class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
@@ -18,5 +19,5 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str) -> list[MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
return []
|
||||
|
||||
@@ -31,7 +31,8 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
|
||||
await self.ap.im_mgr.send(
|
||||
query.message_event,
|
||||
query.resp_message_chain
|
||||
query.resp_message_chain,
|
||||
adapter=query.adapter
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -47,7 +47,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
use_rule = use_rule[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule)
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
|
||||
if res.matching:
|
||||
query.message_chain = res.replacement
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import abc
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
"""判断消息是否匹配规则
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,7 @@ import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class AtBotRule(rule_model.GroupRespondRule):
|
||||
@@ -12,11 +13,12 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
|
||||
if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id))
|
||||
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(mirai.At(query.adapter.bot_account_id))
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=message_chain,
|
||||
|
||||
@@ -2,6 +2,7 @@ import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class PrefixRule(rule_model.GroupRespondRule):
|
||||
@@ -10,7 +11,8 @@ class PrefixRule(rule_model.GroupRespondRule):
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
prefixes = rule_dict['prefix']
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class RandomRespRule(rule_model.GroupRespondRule):
|
||||
@@ -12,7 +13,8 @@ class RandomRespRule(rule_model.GroupRespondRule):
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
random_rate = rule_dict['random']
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class RegExpRule(rule_model.GroupRespondRule):
|
||||
@@ -12,7 +13,8 @@ class RegExpRule(rule_model.GroupRespondRule):
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
) -> entities.RuleJudgeResult:
|
||||
regexps = rule_dict['regexp']
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
@@ -84,7 +84,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
|
||||
@@ -17,11 +17,8 @@ from ..plugin import events
|
||||
# 控制QQ消息输入输出的类
|
||||
class PlatformManager:
|
||||
|
||||
adapter: msadapter.MessageSourceAdapter = None
|
||||
|
||||
@property
|
||||
def bot_account_id(self):
|
||||
return self.adapter.bot_account_id
|
||||
# adapter: msadapter.MessageSourceAdapter = None
|
||||
adapters: list[msadapter.MessageSourceAdapter] = []
|
||||
|
||||
# modern
|
||||
ap: app.Application = None
|
||||
@@ -29,27 +26,13 @@ class PlatformManager:
|
||||
def __init__(self, ap: app.Application = None):
|
||||
|
||||
self.ap = ap
|
||||
self.adapters = []
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy
|
||||
|
||||
adapter_cls = None
|
||||
|
||||
for adapter in msadapter.preregistered_adapters:
|
||||
if adapter.name == self.ap.platform_cfg.data['platform-adapter']:
|
||||
adapter_cls = adapter
|
||||
break
|
||||
if adapter_cls is None:
|
||||
raise Exception('未知的平台适配器: ' + self.ap.platform_cfg.data['platform-adapter'])
|
||||
|
||||
cfg_key = self.ap.platform_cfg.data['platform-adapter'] + '-config'
|
||||
self.adapter = adapter_cls(
|
||||
self.ap.platform_cfg.data[cfg_key],
|
||||
self.ap
|
||||
)
|
||||
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PersonMessageReceived(
|
||||
@@ -68,15 +51,11 @@ class PlatformManager:
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PersonMessageReceived(
|
||||
@@ -96,16 +75,10 @@ class PlatformManager:
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
# nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件
|
||||
if self.ap.platform_cfg.data['platform-adapter'] == 'yiri-mirai':
|
||||
self.adapter.register_listener(
|
||||
StrangerMessage,
|
||||
on_stranger_message
|
||||
)
|
||||
|
||||
async def on_group_message(event: GroupMessage):
|
||||
async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.GroupMessageReceived(
|
||||
@@ -124,15 +97,53 @@ class PlatformManager:
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
index = 0
|
||||
|
||||
self.adapter.register_listener(
|
||||
GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
for adap_cfg in self.ap.platform_cfg.data['platform-adapters']:
|
||||
if adap_cfg['enable']:
|
||||
self.ap.logger.info(f'初始化平台适配器 {index}: {adap_cfg["adapter"]}')
|
||||
index += 1
|
||||
cfg_copy = adap_cfg.copy()
|
||||
del cfg_copy['enable']
|
||||
adapter_name = cfg_copy['adapter']
|
||||
del cfg_copy['adapter']
|
||||
|
||||
async def send(self, event, msg, check_quote=True, check_at_sender=True):
|
||||
found = False
|
||||
|
||||
for adapter in msadapter.preregistered_adapters:
|
||||
if adapter.name == adapter_name:
|
||||
found = True
|
||||
adapter_cls = adapter
|
||||
|
||||
adapter_inst = adapter_cls(
|
||||
cfg_copy,
|
||||
self.ap
|
||||
)
|
||||
self.adapters.append(adapter_inst)
|
||||
|
||||
if adapter_name == 'yiri-mirai':
|
||||
adapter_inst.register_listener(
|
||||
StrangerMessage,
|
||||
on_stranger_message
|
||||
)
|
||||
|
||||
adapter_inst.register_listener(
|
||||
FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
adapter_inst.register_listener(
|
||||
GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
if not found:
|
||||
raise Exception('platform.json 中启用了未知的平台适配器: ' + adapter_name)
|
||||
|
||||
async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
|
||||
|
||||
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
|
||||
|
||||
@@ -143,7 +154,7 @@ class PlatformManager:
|
||||
)
|
||||
)
|
||||
|
||||
await self.adapter.reply_message(
|
||||
await adapter.reply_message(
|
||||
event,
|
||||
msg,
|
||||
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
|
||||
@@ -170,7 +181,21 @@ class PlatformManager:
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
await self.adapter.run_async()
|
||||
tasks = []
|
||||
for adapter in self.adapters:
|
||||
async def exception_wrapper(adapter):
|
||||
try:
|
||||
await adapter.run_async()
|
||||
except Exception as e:
|
||||
self.ap.logger.error('平台适配器运行出错: ' + str(e))
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
tasks.append(exception_wrapper(adapter))
|
||||
|
||||
for task in tasks:
|
||||
asyncio.create_task(task)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error('平台适配器运行出错: ' + str(e))
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
@@ -240,12 +240,12 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None],
|
||||
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None],
|
||||
):
|
||||
async def on_message(event: aiocqhttp.Event):
|
||||
self.bot_account_id = event.self_id
|
||||
try:
|
||||
return await callback(self.event_converter.target2yiri(event))
|
||||
return await callback(self.event_converter.target2yiri(event), self)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -257,7 +257,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None],
|
||||
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
# 加了之后会导致:https://github.com/Lxns-Network/nakuru-project/issues/25
|
||||
# from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
@@ -12,7 +13,6 @@ import nakuru.entities.components as nkc
|
||||
|
||||
from .. import adapter as adapter_model
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
|
||||
|
||||
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
@@ -170,11 +170,11 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
||||
|
||||
listener_list: list[dict]
|
||||
|
||||
ap: app.Application
|
||||
# ap: app.Application
|
||||
|
||||
cfg: dict
|
||||
|
||||
def __init__(self, cfg: dict, ap: app.Application):
|
||||
def __init__(self, cfg: dict, ap):
|
||||
"""初始化nakuru-project的对象"""
|
||||
cfg['port'] = cfg['ws_port']
|
||||
del cfg['ws_port']
|
||||
@@ -257,14 +257,15 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
try:
|
||||
|
||||
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
|
||||
|
||||
# 包装函数
|
||||
async def listener_wrapper(app: nakuru.CQHTTP, source: NakuruProjectAdapter.event_converter.yiri2target(event_type)):
|
||||
print(1111)
|
||||
await callback(self.event_converter.target2yiri(source))
|
||||
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls):
|
||||
await callback(self.event_converter.target2yiri(source), self)
|
||||
|
||||
# 将包装函数和原函数的对应关系存入列表
|
||||
self.listener_list.append(
|
||||
@@ -276,7 +277,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
||||
)
|
||||
|
||||
# 注册监听器
|
||||
self.bot.receiver(self.event_converter.yiri2target(event_type).__name__)(listener_wrapper)
|
||||
self.bot.receiver(source_cls.__name__)(listener_wrapper)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
@@ -284,7 +285,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
|
||||
|
||||
@@ -326,7 +327,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
|
||||
await self.bot._run()
|
||||
self.ap.logger.info("运行 Nakuru 适配器")
|
||||
while True:
|
||||
await asyncio.sleep(100)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def kill(self) -> bool:
|
||||
return False
|
||||
@@ -362,14 +362,14 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
|
||||
try:
|
||||
|
||||
async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]):
|
||||
self.cached_official_messages[str(message.id)] = message
|
||||
await callback(OfficialEventConverter.target2yiri(message))
|
||||
await callback(OfficialEventConverter.target2yiri(message), self)
|
||||
|
||||
for event_handler in event_handler_mapping[event_type]:
|
||||
setattr(self.bot, event_handler, wrapper)
|
||||
@@ -380,7 +380,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
delattr(self.bot, event_handler_mapping[event_type])
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
@@ -95,12 +95,14 @@ class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
|
||||
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
|
||||
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件
|
||||
"""
|
||||
self.bot.on(event_type)(callback)
|
||||
async def wrapper(event: mirai.Event):
|
||||
await callback(event, self)
|
||||
self.bot.on(event_type)(wrapper)
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event], None]
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
|
||||
@@ -14,3 +14,4 @@ tiktoken
|
||||
PyYaml
|
||||
aiohttp
|
||||
pydantic
|
||||
aioconsole
|
||||
@@ -1,30 +1,38 @@
|
||||
{
|
||||
"platform-adapter": "yiri-mirai",
|
||||
"yiri-mirai-config": {
|
||||
"adapter": "WebSocketAdapter",
|
||||
"host": "127.0.0.1",
|
||||
"port": 8080,
|
||||
"verifyKey": "yirimirai",
|
||||
"qq": 123456789
|
||||
},
|
||||
"nakuru-config": {
|
||||
"host": "127.0.0.1",
|
||||
"ws_port": 8080,
|
||||
"http_port": 5700,
|
||||
"token": ""
|
||||
},
|
||||
"aiocqhttp-config": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 8080
|
||||
},
|
||||
"qq-botpy-config": {
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"intents": [
|
||||
"public_guild_messages",
|
||||
"direct_message"
|
||||
]
|
||||
},
|
||||
"platform-adapters": [
|
||||
{
|
||||
"adapter": "yiri-mirai",
|
||||
"enable": false,
|
||||
"host": "127.0.0.1",
|
||||
"port": 8080,
|
||||
"verifyKey": "yirimirai",
|
||||
"qq": 123456789
|
||||
},
|
||||
{
|
||||
"adapter": "nakuru",
|
||||
"enable": false,
|
||||
"host": "127.0.0.1",
|
||||
"ws_port": 8080,
|
||||
"http_port": 5700,
|
||||
"token": ""
|
||||
},
|
||||
{
|
||||
"adapter": "aiocqhttp",
|
||||
"enable": false,
|
||||
"host": "127.0.0.1",
|
||||
"port": 8080
|
||||
},
|
||||
{
|
||||
"adapter": "qq-botpy",
|
||||
"enable": false,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"intents": [
|
||||
"public_guild_messages",
|
||||
"direct_message"
|
||||
]
|
||||
}
|
||||
],
|
||||
"track-function-calls": true,
|
||||
"quote-origin": false,
|
||||
"at-sender": false,
|
||||
|
||||
Reference in New Issue
Block a user