From 1af2cb5bc2f31279613815a69cf5a239b4672498 Mon Sep 17 00:00:00 2001 From: huanghuoguoguo <1051233107@qq.com> Date: Fri, 8 May 2026 13:53:59 +0800 Subject: [PATCH] feat(test): add shared test factories package Create tests/factories/ with reusable test factories: - FakeApp: mock application with all dependencies - Message chains: text_chain, mention_chain, image_chain - Query factories: text_query, group_text_query, command_query, etc. No test changes - maintains backward compatibility. Co-Authored-By: Claude Opus 4.7 --- tests/factories/__init__.py | 48 ++++++ tests/factories/app.py | 115 +++++++++++++ tests/factories/message.py | 321 ++++++++++++++++++++++++++++++++++++ 3 files changed, 484 insertions(+) create mode 100644 tests/factories/__init__.py create mode 100644 tests/factories/app.py create mode 100644 tests/factories/message.py diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py new file mode 100644 index 00000000..a2963799 --- /dev/null +++ b/tests/factories/__init__.py @@ -0,0 +1,48 @@ +""" +Shared test factories for LangBot tests. + +Provides reusable factories for: +- Fake application (app.py) +- Messages and queries (message.py) +- Fake providers (provider.py) +- Fake platforms (platform.py) + +Usage: + from tests.factories import FakeApp, text_query, FakeProvider + + app = FakeApp() + query = text_query("hello") + provider = FakeProvider.returns("response") +""" + +from tests.factories.app import FakeApp, fake_app +from tests.factories.message import ( + text_chain, + group_text_chain, + mention_chain, + image_chain, + text_query, + group_text_query, + private_text_query, + command_query, + mention_query, + empty_query, +) + +__all__ = [ + # App + "FakeApp", + "fake_app", + # Message chains + "text_chain", + "group_text_chain", + "mention_chain", + "image_chain", + # Queries + "text_query", + "group_text_query", + "private_text_query", + "command_query", + "mention_query", + "empty_query", +] \ No newline at end of file diff --git a/tests/factories/app.py b/tests/factories/app.py new file mode 100644 index 00000000..80b0650a --- /dev/null +++ b/tests/factories/app.py @@ -0,0 +1,115 @@ +""" +Fake application factory for tests. + +Provides a mock Application object with all dependencies needed by pipeline stages. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, Mock + +import langbot_plugin.api.entities.builtin.provider.session as provider_session + + +class FakeApp: + """Mock Application object providing all basic dependencies needed by stages.""" + + def __init__( + self, + *, + command_prefix: list[str] = ["/", "!"], + command_enable: bool = True, + pipeline_concurrency: int = 10, + ): + self.logger = self._create_mock_logger() + self.sess_mgr = self._create_mock_session_manager() + self.model_mgr = self._create_mock_model_manager() + self.tool_mgr = self._create_mock_tool_manager() + self.plugin_connector = self._create_mock_plugin_connector() + self.persistence_mgr = self._create_mock_persistence_manager() + self.query_pool = self._create_mock_query_pool() + self.instance_config = self._create_mock_instance_config( + command_prefix=command_prefix, + command_enable=command_enable, + pipeline_concurrency=pipeline_concurrency, + ) + self.task_mgr = self._create_mock_task_manager() + + # Captured outbound messages (for assertions) + self._outbound_messages: list = [] + + def _create_mock_logger(self): + logger = Mock() + logger.debug = Mock() + logger.info = Mock() + logger.error = Mock() + logger.warning = Mock() + return logger + + def _create_mock_session_manager(self): + sess_mgr = AsyncMock() + sess_mgr.get_session = AsyncMock() + sess_mgr.get_conversation = AsyncMock() + return sess_mgr + + def _create_mock_model_manager(self): + model_mgr = AsyncMock() + model_mgr.get_model_by_uuid = AsyncMock() + return model_mgr + + def _create_mock_tool_manager(self): + tool_mgr = AsyncMock() + tool_mgr.get_all_tools = AsyncMock(return_value=[]) + return tool_mgr + + def _create_mock_plugin_connector(self): + plugin_connector = AsyncMock() + plugin_connector.emit_event = AsyncMock() + return plugin_connector + + def _create_mock_persistence_manager(self): + persistence_mgr = AsyncMock() + persistence_mgr.execute_async = AsyncMock() + return persistence_mgr + + def _create_mock_query_pool(self): + query_pool = Mock() + query_pool.cached_queries = {} + query_pool.queries = [] + query_pool.condition = AsyncMock() + return query_pool + + def _create_mock_instance_config( + self, + command_prefix: list[str], + command_enable: bool, + pipeline_concurrency: int, + ): + instance_config = Mock() + instance_config.data = { + "command": {"prefix": command_prefix, "enable": command_enable}, + "concurrency": {"pipeline": pipeline_concurrency}, + } + return instance_config + + def _create_mock_task_manager(self): + task_mgr = Mock() + task_mgr.create_task = Mock() + return task_mgr + + def capture_message(self, message): + """Capture an outbound message for test assertions.""" + self._outbound_messages.append(message) + + def get_outbound_messages(self) -> list: + """Get all captured outbound messages.""" + return self._outbound_messages.copy() + + def clear_outbound_messages(self): + """Clear captured outbound messages.""" + self._outbound_messages.clear() + + +def fake_app(**kwargs) -> FakeApp: + """Create a FakeApp instance with optional overrides.""" + return FakeApp(**kwargs) \ No newline at end of file diff --git a/tests/factories/message.py b/tests/factories/message.py new file mode 100644 index 00000000..ecebe2f9 --- /dev/null +++ b/tests/factories/message.py @@ -0,0 +1,321 @@ +""" +Message and query factories for tests. + +Provides reusable factories for creating message chains, events, and query objects. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, Mock +import typing + +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.entities.builtin.provider.session as provider_session + + +# Counter for generating unique IDs +_query_counter = 0 + + +def _next_query_id() -> int: + """Generate a unique query ID.""" + global _query_counter + _query_counter += 1 + return _query_counter + + +# ============== Message Chain Factories ============== + + +def text_chain(text: str = "hello") -> platform_message.MessageChain: + """Create a simple text message chain.""" + return platform_message.MessageChain([ + platform_message.Plain(text=text), + ]) + + +def group_text_chain(text: str = "hello") -> platform_message.MessageChain: + """Create a group text message chain (same as text_chain, context provided by event).""" + return text_chain(text) + + +def mention_chain( + text: str = "hello", + target: typing.Union[int, str] = 12345, +) -> platform_message.MessageChain: + """Create a message chain with @mention.""" + return platform_message.MessageChain([ + platform_message.At(target=target), + platform_message.Plain(text=f" {text}"), + ]) + + +def image_chain( + text: str = "", + url: str = "https://example.com/image.png", +) -> platform_message.MessageChain: + """Create a message chain with an image.""" + components = [] + if text: + components.append(platform_message.Plain(text=text)) + components.append(platform_message.Image(url=url)) + return platform_message.MessageChain(components) + + +def command_chain( + command: str = "help", + prefix: str = "/", +) -> platform_message.MessageChain: + """Create a command message chain.""" + return platform_message.MessageChain([ + platform_message.Plain(text=f"{prefix}{command}"), + ]) + + +# ============== Message Event Factories ============== + + +def friend_message_event( + message_chain: platform_message.MessageChain, + sender_id: typing.Union[int, str] = 12345, + nickname: str = "TestUser", +) -> platform_events.FriendMessage: + """Create a friend (private) message event.""" + sender = platform_entities.Friend( + id=sender_id, + nickname=nickname, + remark=None, + ) + return platform_events.FriendMessage( + type="FriendMessage", + sender=sender, + message_chain=message_chain, + time=1609459200, + ) + + +def group_message_event( + message_chain: platform_message.MessageChain, + sender_id: typing.Union[int, str] = 12345, + sender_name: str = "TestUser", + group_id: typing.Union[int, str] = 99999, + group_name: str = "TestGroup", +) -> platform_events.GroupMessage: + """Create a group message event.""" + group = platform_entities.Group( + id=group_id, + name=group_name, + permission=platform_entities.Permission.Member, + ) + sender = platform_entities.GroupMember( + id=sender_id, + member_name=sender_name, + permission=platform_entities.Permission.Member, + group=group, + ) + return platform_events.GroupMessage( + type="GroupMessage", + sender=sender, + message_chain=message_chain, + time=1609459200, + ) + + +# ============== Mock Adapter Factory ============== + + +def mock_adapter() -> Mock: + """Create a mock platform adapter.""" + adapter = AsyncMock() + adapter.is_stream_output_supported = AsyncMock(return_value=False) + adapter.reply_message = AsyncMock() + adapter.reply_message_chunk = AsyncMock() + return adapter + + +# ============== Query Factories ============== + + +def _base_query( + message_chain: platform_message.MessageChain, + message_event: platform_events.MessageEvent, + launcher_type: provider_session.LauncherTypes, + launcher_id: typing.Union[int, str], + sender_id: typing.Union[int, str], + adapter: Mock, + **overrides, +) -> pipeline_query.Query: + """Create a base query with model_construct to bypass validation.""" + query_id = _next_query_id() + + base_data = { + "query_id": query_id, + "launcher_type": launcher_type, + "launcher_id": launcher_id, + "sender_id": sender_id, + "message_chain": message_chain, + "message_event": message_event, + "adapter": adapter, + "pipeline_uuid": "test-pipeline-uuid", + "bot_uuid": "test-bot-uuid", + "pipeline_config": { + "ai": { + "runner": {"runner": "local-agent"}, + "local-agent": { + "model": {"primary": "test-model-uuid", "fallbacks": []}, + "prompt": "test-prompt", + }, + }, + "output": {"misc": {"at-sender": False, "quote-origin": False}}, + "trigger": {"misc": {"combine-quote-message": False}}, + }, + "session": None, + "prompt": None, + "messages": [], + "user_message": None, + "use_funcs": [], + "use_llm_model_uuid": None, + "variables": {}, + "resp_messages": [], + "resp_message_chain": None, + "current_stage_name": None, + } + + # Apply overrides + for key, value in overrides.items(): + base_data[key] = value + + return pipeline_query.Query.model_construct(**base_data) + + +def text_query( + text: str = "hello", + sender_id: typing.Union[int, str] = 12345, + **overrides, +) -> pipeline_query.Query: + """Create a basic text query (private chat).""" + chain = text_chain(text) + event = friend_message_event(chain, sender_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=sender_id, + sender_id=sender_id, + adapter=adapter, + **overrides, + ) + + +def private_text_query( + text: str = "hello", + sender_id: typing.Union[int, str] = 12345, + **overrides, +) -> pipeline_query.Query: + """Create a private text query (alias for text_query).""" + return text_query(text, sender_id, **overrides) + + +def group_text_query( + text: str = "hello", + sender_id: typing.Union[int, str] = 12345, + group_id: typing.Union[int, str] = 99999, + **overrides, +) -> pipeline_query.Query: + """Create a group text query.""" + chain = text_chain(text) + event = group_message_event(chain, sender_id, group_id=group_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.GROUP, + launcher_id=group_id, + sender_id=sender_id, + adapter=adapter, + **overrides, + ) + + +def command_query( + command: str = "help", + prefix: str = "/", + sender_id: typing.Union[int, str] = 12345, + **overrides, +) -> pipeline_query.Query: + """Create a command-like query.""" + chain = command_chain(command, prefix) + event = friend_message_event(chain, sender_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=sender_id, + sender_id=sender_id, + adapter=adapter, + **overrides, + ) + + +def mention_query( + text: str = "hello", + target: typing.Union[int, str] = 12345, + sender_id: typing.Union[int, str] = 12345, + group_id: typing.Union[int, str] = 99999, + **overrides, +) -> pipeline_query.Query: + """Create a mention-bot query (group chat with @mention).""" + chain = mention_chain(text, target) + event = group_message_event(chain, sender_id, group_id=group_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.GROUP, + launcher_id=group_id, + sender_id=sender_id, + adapter=adapter, + **overrides, + ) + + +def empty_query(**overrides) -> pipeline_query.Query: + """Create an empty message query.""" + chain = platform_message.MessageChain([]) + event = friend_message_event(chain) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=12345, + sender_id=12345, + adapter=adapter, + **overrides, + ) + + +def image_query( + text: str = "", + url: str = "https://example.com/image.png", + sender_id: typing.Union[int, str] = 12345, + **overrides, +) -> pipeline_query.Query: + """Create an image query.""" + chain = image_chain(text, url) + event = friend_message_event(chain, sender_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=sender_id, + sender_id=sender_id, + adapter=adapter, + **overrides, + ) \ No newline at end of file