feat(test): add shared test factories package

Create tests/factories/ with reusable test factories:
- FakeApp: mock application with all dependencies
- Message chains: text_chain, mention_chain, image_chain
- Query factories: text_query, group_text_query, command_query, etc.

No test changes - maintains backward compatibility.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
huanghuoguoguo
2026-05-08 13:53:59 +08:00
parent 37641f05f2
commit 1af2cb5bc2
3 changed files with 484 additions and 0 deletions

View File

@@ -0,0 +1,48 @@
"""
Shared test factories for LangBot tests.
Provides reusable factories for:
- Fake application (app.py)
- Messages and queries (message.py)
- Fake providers (provider.py)
- Fake platforms (platform.py)
Usage:
from tests.factories import FakeApp, text_query, FakeProvider
app = FakeApp()
query = text_query("hello")
provider = FakeProvider.returns("response")
"""
from tests.factories.app import FakeApp, fake_app
from tests.factories.message import (
text_chain,
group_text_chain,
mention_chain,
image_chain,
text_query,
group_text_query,
private_text_query,
command_query,
mention_query,
empty_query,
)
__all__ = [
# App
"FakeApp",
"fake_app",
# Message chains
"text_chain",
"group_text_chain",
"mention_chain",
"image_chain",
# Queries
"text_query",
"group_text_query",
"private_text_query",
"command_query",
"mention_query",
"empty_query",
]

115
tests/factories/app.py Normal file
View File

@@ -0,0 +1,115 @@
"""
Fake application factory for tests.
Provides a mock Application object with all dependencies needed by pipeline stages.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock
import langbot_plugin.api.entities.builtin.provider.session as provider_session
class FakeApp:
"""Mock Application object providing all basic dependencies needed by stages."""
def __init__(
self,
*,
command_prefix: list[str] = ["/", "!"],
command_enable: bool = True,
pipeline_concurrency: int = 10,
):
self.logger = self._create_mock_logger()
self.sess_mgr = self._create_mock_session_manager()
self.model_mgr = self._create_mock_model_manager()
self.tool_mgr = self._create_mock_tool_manager()
self.plugin_connector = self._create_mock_plugin_connector()
self.persistence_mgr = self._create_mock_persistence_manager()
self.query_pool = self._create_mock_query_pool()
self.instance_config = self._create_mock_instance_config(
command_prefix=command_prefix,
command_enable=command_enable,
pipeline_concurrency=pipeline_concurrency,
)
self.task_mgr = self._create_mock_task_manager()
# Captured outbound messages (for assertions)
self._outbound_messages: list = []
def _create_mock_logger(self):
logger = Mock()
logger.debug = Mock()
logger.info = Mock()
logger.error = Mock()
logger.warning = Mock()
return logger
def _create_mock_session_manager(self):
sess_mgr = AsyncMock()
sess_mgr.get_session = AsyncMock()
sess_mgr.get_conversation = AsyncMock()
return sess_mgr
def _create_mock_model_manager(self):
model_mgr = AsyncMock()
model_mgr.get_model_by_uuid = AsyncMock()
return model_mgr
def _create_mock_tool_manager(self):
tool_mgr = AsyncMock()
tool_mgr.get_all_tools = AsyncMock(return_value=[])
return tool_mgr
def _create_mock_plugin_connector(self):
plugin_connector = AsyncMock()
plugin_connector.emit_event = AsyncMock()
return plugin_connector
def _create_mock_persistence_manager(self):
persistence_mgr = AsyncMock()
persistence_mgr.execute_async = AsyncMock()
return persistence_mgr
def _create_mock_query_pool(self):
query_pool = Mock()
query_pool.cached_queries = {}
query_pool.queries = []
query_pool.condition = AsyncMock()
return query_pool
def _create_mock_instance_config(
self,
command_prefix: list[str],
command_enable: bool,
pipeline_concurrency: int,
):
instance_config = Mock()
instance_config.data = {
"command": {"prefix": command_prefix, "enable": command_enable},
"concurrency": {"pipeline": pipeline_concurrency},
}
return instance_config
def _create_mock_task_manager(self):
task_mgr = Mock()
task_mgr.create_task = Mock()
return task_mgr
def capture_message(self, message):
"""Capture an outbound message for test assertions."""
self._outbound_messages.append(message)
def get_outbound_messages(self) -> list:
"""Get all captured outbound messages."""
return self._outbound_messages.copy()
def clear_outbound_messages(self):
"""Clear captured outbound messages."""
self._outbound_messages.clear()
def fake_app(**kwargs) -> FakeApp:
"""Create a FakeApp instance with optional overrides."""
return FakeApp(**kwargs)

321
tests/factories/message.py Normal file
View File

@@ -0,0 +1,321 @@
"""
Message and query factories for tests.
Provides reusable factories for creating message chains, events, and query objects.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, Mock
import typing
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.entities.builtin.provider.session as provider_session
# Counter for generating unique IDs
_query_counter = 0
def _next_query_id() -> int:
"""Generate a unique query ID."""
global _query_counter
_query_counter += 1
return _query_counter
# ============== Message Chain Factories ==============
def text_chain(text: str = "hello") -> platform_message.MessageChain:
"""Create a simple text message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=text),
])
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
"""Create a group text message chain (same as text_chain, context provided by event)."""
return text_chain(text)
def mention_chain(
text: str = "hello",
target: typing.Union[int, str] = 12345,
) -> platform_message.MessageChain:
"""Create a message chain with @mention."""
return platform_message.MessageChain([
platform_message.At(target=target),
platform_message.Plain(text=f" {text}"),
])
def image_chain(
text: str = "",
url: str = "https://example.com/image.png",
) -> platform_message.MessageChain:
"""Create a message chain with an image."""
components = []
if text:
components.append(platform_message.Plain(text=text))
components.append(platform_message.Image(url=url))
return platform_message.MessageChain(components)
def command_chain(
command: str = "help",
prefix: str = "/",
) -> platform_message.MessageChain:
"""Create a command message chain."""
return platform_message.MessageChain([
platform_message.Plain(text=f"{prefix}{command}"),
])
# ============== Message Event Factories ==============
def friend_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
nickname: str = "TestUser",
) -> platform_events.FriendMessage:
"""Create a friend (private) message event."""
sender = platform_entities.Friend(
id=sender_id,
nickname=nickname,
remark=None,
)
return platform_events.FriendMessage(
type="FriendMessage",
sender=sender,
message_chain=message_chain,
time=1609459200,
)
def group_message_event(
message_chain: platform_message.MessageChain,
sender_id: typing.Union[int, str] = 12345,
sender_name: str = "TestUser",
group_id: typing.Union[int, str] = 99999,
group_name: str = "TestGroup",
) -> platform_events.GroupMessage:
"""Create a group message event."""
group = platform_entities.Group(
id=group_id,
name=group_name,
permission=platform_entities.Permission.Member,
)
sender = platform_entities.GroupMember(
id=sender_id,
member_name=sender_name,
permission=platform_entities.Permission.Member,
group=group,
)
return platform_events.GroupMessage(
type="GroupMessage",
sender=sender,
message_chain=message_chain,
time=1609459200,
)
# ============== Mock Adapter Factory ==============
def mock_adapter() -> Mock:
"""Create a mock platform adapter."""
adapter = AsyncMock()
adapter.is_stream_output_supported = AsyncMock(return_value=False)
adapter.reply_message = AsyncMock()
adapter.reply_message_chunk = AsyncMock()
return adapter
# ============== Query Factories ==============
def _base_query(
message_chain: platform_message.MessageChain,
message_event: platform_events.MessageEvent,
launcher_type: provider_session.LauncherTypes,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
adapter: Mock,
**overrides,
) -> pipeline_query.Query:
"""Create a base query with model_construct to bypass validation."""
query_id = _next_query_id()
base_data = {
"query_id": query_id,
"launcher_type": launcher_type,
"launcher_id": launcher_id,
"sender_id": sender_id,
"message_chain": message_chain,
"message_event": message_event,
"adapter": adapter,
"pipeline_uuid": "test-pipeline-uuid",
"bot_uuid": "test-bot-uuid",
"pipeline_config": {
"ai": {
"runner": {"runner": "local-agent"},
"local-agent": {
"model": {"primary": "test-model-uuid", "fallbacks": []},
"prompt": "test-prompt",
},
},
"output": {"misc": {"at-sender": False, "quote-origin": False}},
"trigger": {"misc": {"combine-quote-message": False}},
},
"session": None,
"prompt": None,
"messages": [],
"user_message": None,
"use_funcs": [],
"use_llm_model_uuid": None,
"variables": {},
"resp_messages": [],
"resp_message_chain": None,
"current_stage_name": None,
}
# Apply overrides
for key, value in overrides.items():
base_data[key] = value
return pipeline_query.Query.model_construct(**base_data)
def text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a basic text query (private chat)."""
chain = text_chain(text)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def private_text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a private text query (alias for text_query)."""
return text_query(text, sender_id, **overrides)
def group_text_query(
text: str = "hello",
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
) -> pipeline_query.Query:
"""Create a group text query."""
chain = text_chain(text)
event = group_message_event(chain, sender_id, group_id=group_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=group_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def command_query(
command: str = "help",
prefix: str = "/",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create a command-like query."""
chain = command_chain(command, prefix)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def mention_query(
text: str = "hello",
target: typing.Union[int, str] = 12345,
sender_id: typing.Union[int, str] = 12345,
group_id: typing.Union[int, str] = 99999,
**overrides,
) -> pipeline_query.Query:
"""Create a mention-bot query (group chat with @mention)."""
chain = mention_chain(text, target)
event = group_message_event(chain, sender_id, group_id=group_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=group_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)
def empty_query(**overrides) -> pipeline_query.Query:
"""Create an empty message query."""
chain = platform_message.MessageChain([])
event = friend_message_event(chain)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
adapter=adapter,
**overrides,
)
def image_query(
text: str = "",
url: str = "https://example.com/image.png",
sender_id: typing.Union[int, str] = 12345,
**overrides,
) -> pipeline_query.Query:
"""Create an image query."""
chain = image_chain(text, url)
event = friend_message_event(chain, sender_id)
adapter = mock_adapter()
return _base_query(
message_chain=chain,
message_event=event,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=sender_id,
sender_id=sender_id,
adapter=adapter,
**overrides,
)