diff --git a/pkg/api/http/service/bot.py b/pkg/api/http/service/bot.py index b8c0a46b..fcd81fcb 100644 --- a/pkg/api/http/service/bot.py +++ b/pkg/api/http/service/bot.py @@ -44,11 +44,16 @@ class BotService: async def create_bot(self, bot_data: dict) -> str: """创建机器人""" + # TODO: 检查配置信息格式 bot_data['uuid'] = str(uuid.uuid4()) await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(persistence_bot.Bot).values(bot_data) ) - # TODO: 加载机器人到机器人管理器 + + bot = await self.get_bot(bot_data['uuid']) + + await self.ap.platform_mgr.load_bot(bot) + return bot_data['uuid'] async def update_bot(self, bot_uuid: str, bot_data: dict) -> None: @@ -58,13 +63,21 @@ class BotService: await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid) ) - # TODO: 加载机器人到机器人管理器 + await self.ap.platform_mgr.remove_bot(bot_uuid) + + # select from db + bot = await self.get_bot(bot_uuid) + + runtime_bot = await self.ap.platform_mgr.load_bot(bot) + + if runtime_bot.enable: + await runtime_bot.run() async def delete_bot(self, bot_uuid: str) -> None: """删除机器人""" + await self.ap.platform_mgr.remove_bot(bot_uuid) await self.ap.persistence_mgr.execute_async( sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) ) - # TODO: 从机器人管理器中删除机器人 diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index e96f014a..f6d8cde1 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -35,7 +35,10 @@ class ModelsService: **model_data ) ) - await self.ap.model_mgr.load_llm_model(model_data) + + llm_model = await self.get_llm_model(model_data['uuid']) + + await self.ap.model_mgr.load_llm_model(llm_model) return model_data['uuid'] @@ -60,7 +63,10 @@ class ModelsService: ) await self.ap.model_mgr.remove_llm_model(model_uuid) - await self.ap.model_mgr.load_llm_model(model_data) + + llm_model = await self.get_llm_model(model_uuid) + + await self.ap.model_mgr.load_llm_model(llm_model) async def delete_llm_model(self, model_uuid: str) -> None: await self.ap.persistence_mgr.execute_async( diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index d5ee7564..210ee9ad 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -6,13 +6,14 @@ import sys import logging import asyncio import traceback +import sqlalchemy from .sources import qqofficial # FriendMessage, Image, MessageChain, Plain from ..platform import adapter as msadapter -from ..core import app, entities as core_entities +from ..core import app, entities as core_entities, taskmgr from ..plugin import events from .types import message as platform_message from .types import events as platform_events @@ -20,11 +21,64 @@ from .types import entities as platform_entities from ..discover import engine +from ..entity.persistence import bot as persistence_bot + # 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题 from . import types as mirai sys.modules['mirai'] = mirai +class RuntimeBot: + """运行时机器人""" + + ap: app.Application + + bot_entity: persistence_bot.Bot + + enable: bool + + adapter: msadapter.MessagePlatformAdapter + + task_wrapper: taskmgr.TaskWrapper + + task_context: taskmgr.TaskContext + + def __init__(self, ap: app.Application, bot_entity: persistence_bot.Bot, adapter: msadapter.MessagePlatformAdapter): + self.ap = ap + self.bot_entity = bot_entity + self.enable = bot_entity.enable + self.adapter = adapter + self.task_context = taskmgr.TaskContext() + + async def run(self): + + async def exception_wrapper(): + try: + self.task_context.set_current_action('Running...') + await self.adapter.run_async() + self.task_context.set_current_action('Exited.') + except Exception as e: + if isinstance(e, asyncio.CancelledError): + self.task_context.set_current_action('Exited.') + return + self.task_context.set_current_action('Exited with error.') + self.task_context.log(f'平台适配器运行出错: {e}') + self.task_context.log(f"Traceback: {traceback.format_exc()}") + self.ap.logger.error(f'平台适配器运行出错: {e}') + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + + self.task_wrapper = self.ap.task_mgr.create_task( + exception_wrapper(), + kind="platform-adapter", + name=f"platform-adapter-{self.adapter.__class__.__name__}", + context=self.task_context, + scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM] + ) + + async def shutdown(self): + await self.adapter.kill() + + # 控制QQ消息输入输出的类 class PlatformManager: @@ -33,22 +87,55 @@ class PlatformManager: message_platform_adapter_components: list[engine.Component] = [] - # modern + # ====== 4.0 ====== ap: app.Application = None + bots: list[RuntimeBot] + + adapter_components: list[engine.Component] + + adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] + def __init__(self, ap: app.Application = None): self.ap = ap self.adapters = [] + self.bots = [] + self.adapter_components = [] + self.adapter_dict = {} async def initialize(self): - components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter') + self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter') + adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {} + for component in self.adapter_components: + adapter_dict[component.metadata.name] = component.get_python_component_class() + self.adapter_dict = adapter_dict - self.message_platform_adapter_components = components + await self.load_bots_from_db() - # from .sources import nakuru, aiocqhttp, qqbotpy, qqofficial, wecom, lark, discord, gewechat, officialaccount, telegram, dingtalk + async def load_bots_from_db(self): + self.ap.logger.info('Loading bots from db...') + self.bots = [] + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_bot.Bot) + ) + + bots = result.all() + + for bot in bots: + # load all bots here, enable or disable will be handled in runtime + await self.load_bot(bot) + + async def load_bot(self, bot_entity: persistence_bot.Bot | sqlalchemy.Row[persistence_bot.Bot] | dict) -> RuntimeBot: + """加载机器人""" + if isinstance(bot_entity, sqlalchemy.Row): + bot_entity = persistence_bot.Bot(**bot_entity._mapping) + elif isinstance(bot_entity, dict): + bot_entity = persistence_bot.Bot(**bot_entity) + async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter): await self.ap.query_pool.add_query( @@ -70,45 +157,44 @@ class PlatformManager: message_chain=event.message_chain, adapter=adapter ) - - index = 0 - for adap_cfg in self.ap.platform_cfg.data['platform-adapters']: - if adap_cfg['enable']: - self.ap.logger.info(f'初始化平台适配器 {index}: {adap_cfg["adapter"]}') - index += 1 - cfg_copy = adap_cfg.copy() - del cfg_copy['enable'] - adapter_name = cfg_copy['adapter'] - del cfg_copy['adapter'] + adapter_inst = self.adapter_dict[bot_entity.adapter]( + bot_entity.adapter_config, + self.ap + ) - found = False + adapter_inst.register_listener( + platform_events.FriendMessage, + on_friend_message + ) + adapter_inst.register_listener( + platform_events.GroupMessage, + on_group_message + ) - for adapter in self.message_platform_adapter_components: - if adapter.metadata.name == adapter_name: - found = True - adapter_cls = adapter.get_python_component_class() - - adapter_inst = adapter_cls( - cfg_copy, - self.ap - ) - self.adapters.append(adapter_inst) + runtime_bot = RuntimeBot( + ap=self.ap, + bot_entity=bot_entity, + adapter=adapter_inst + ) - adapter_inst.register_listener( - platform_events.FriendMessage, - on_friend_message - ) - adapter_inst.register_listener( - platform_events.GroupMessage, - on_group_message - ) - - if not found: - raise Exception('platform.json 中启用了未知的平台适配器: ' + adapter_name) - - if len(self.adapters) == 0: - self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。') + self.bots.append(runtime_bot) + + return runtime_bot + + async def get_bot_by_uuid(self, bot_uuid: str) -> RuntimeBot | None: + for bot in self.bots: + if bot.bot_entity.uuid == bot_uuid: + return bot + return None + + async def remove_bot(self, bot_uuid: str): + for bot in self.bots: + if bot.bot_entity.uuid == bot_uuid: + if bot.enable: + await bot.shutdown() + self.bots.remove(bot) + return def get_available_adapters_info(self) -> list[dict]: return [ @@ -168,35 +254,14 @@ class PlatformManager: quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False ) - async def run(self): - try: - tasks = [] - for adapter in self.adapters: - async def exception_wrapper(adapter: msadapter.MessagePlatformAdapter): - try: - await adapter.run_async() - except Exception as e: - if isinstance(e, asyncio.CancelledError): - return - self.ap.logger.error('平台适配器运行出错: ' + str(e)) - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + async def run(self): + # This method will only be called when the application launching + for bot in self.bots: + if bot.enable: + await bot.run() - tasks.append(exception_wrapper(adapter)) - - - for task in tasks: - self.ap.task_mgr.create_task( - task, - kind="platform-adapter", - name=f"platform-adapter-{adapter.__class__.__name__}", - scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM], - ) - - except Exception as e: - self.ap.logger.error('平台适配器运行出错: ' + str(e)) - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - async def shutdown(self): - for adapter in self.adapters: - await adapter.kill() + for bot in self.bots: + if bot.enable: + await bot.shutdown() self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM) \ No newline at end of file