""" Message and query factories for tests. Provides reusable factories for creating message chains, events, and query objects. """ from __future__ import annotations from unittest.mock import AsyncMock, Mock import typing import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.platform.events as platform_events import langbot_plugin.api.entities.builtin.platform.entities as platform_entities import langbot_plugin.api.entities.builtin.provider.session as provider_session # Counter for generating unique IDs _query_counter = 0 def _next_query_id() -> int: """Generate a unique query ID.""" global _query_counter _query_counter += 1 return _query_counter # ============== Message Chain Factories ============== def text_chain(text: str = "hello") -> platform_message.MessageChain: """Create a simple text message chain.""" return platform_message.MessageChain([ platform_message.Plain(text=text), ]) def group_text_chain(text: str = "hello") -> platform_message.MessageChain: """Create a group text message chain (same as text_chain, context provided by event).""" return text_chain(text) def mention_chain( text: str = "hello", target: typing.Union[int, str] = 12345, ) -> platform_message.MessageChain: """Create a message chain with @mention.""" return platform_message.MessageChain([ platform_message.At(target=target), platform_message.Plain(text=f" {text}"), ]) def image_chain( text: str = "", url: str = "https://example.com/image.png", ) -> platform_message.MessageChain: """Create a message chain with an image.""" components = [] if text: components.append(platform_message.Plain(text=text)) components.append(platform_message.Image(url=url)) return platform_message.MessageChain(components) def command_chain( command: str = "help", prefix: str = "/", ) -> platform_message.MessageChain: """Create a command message chain.""" return platform_message.MessageChain([ platform_message.Plain(text=f"{prefix}{command}"), ]) # ============== Message Event Factories ============== def friend_message_event( message_chain: platform_message.MessageChain, sender_id: typing.Union[int, str] = 12345, nickname: str = "TestUser", ) -> platform_events.FriendMessage: """Create a friend (private) message event.""" sender = platform_entities.Friend( id=sender_id, nickname=nickname, remark=None, ) return platform_events.FriendMessage( type="FriendMessage", sender=sender, message_chain=message_chain, time=1609459200, ) def group_message_event( message_chain: platform_message.MessageChain, sender_id: typing.Union[int, str] = 12345, sender_name: str = "TestUser", group_id: typing.Union[int, str] = 99999, group_name: str = "TestGroup", ) -> platform_events.GroupMessage: """Create a group message event.""" group = platform_entities.Group( id=group_id, name=group_name, permission=platform_entities.Permission.Member, ) sender = platform_entities.GroupMember( id=sender_id, member_name=sender_name, permission=platform_entities.Permission.Member, group=group, ) return platform_events.GroupMessage( type="GroupMessage", sender=sender, message_chain=message_chain, time=1609459200, ) # ============== Mock Adapter Factory ============== def mock_adapter() -> Mock: """Create a mock platform adapter.""" adapter = AsyncMock() adapter.is_stream_output_supported = AsyncMock(return_value=False) adapter.reply_message = AsyncMock() adapter.reply_message_chunk = AsyncMock() return adapter # ============== Query Factories ============== def _base_query( message_chain: platform_message.MessageChain, message_event: platform_events.MessageEvent, launcher_type: provider_session.LauncherTypes, launcher_id: typing.Union[int, str], sender_id: typing.Union[int, str], adapter: Mock, **overrides, ) -> pipeline_query.Query: """Create a base query with model_construct to bypass validation.""" query_id = _next_query_id() base_data = { "query_id": query_id, "launcher_type": launcher_type, "launcher_id": launcher_id, "sender_id": sender_id, "message_chain": message_chain, "message_event": message_event, "adapter": adapter, "pipeline_uuid": "test-pipeline-uuid", "bot_uuid": "test-bot-uuid", "pipeline_config": { "ai": { "runner": {"runner": "local-agent"}, "local-agent": { "model": {"primary": "test-model-uuid", "fallbacks": []}, "prompt": "test-prompt", }, }, "output": {"misc": {"at-sender": False, "quote-origin": False}}, "trigger": {"misc": {"combine-quote-message": False}}, }, "session": None, "prompt": None, "messages": [], "user_message": None, "use_funcs": [], "use_llm_model_uuid": None, "variables": {}, "resp_messages": [], "resp_message_chain": None, "current_stage_name": None, } # Apply overrides for key, value in overrides.items(): base_data[key] = value return pipeline_query.Query.model_construct(**base_data) def text_query( text: str = "hello", sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: """Create a basic text query (private chat).""" chain = text_chain(text) event = friend_message_event(chain, sender_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=sender_id, sender_id=sender_id, adapter=adapter, **overrides, ) def private_text_query( text: str = "hello", sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: """Create a private text query (alias for text_query).""" return text_query(text, sender_id, **overrides) def group_text_query( text: str = "hello", sender_id: typing.Union[int, str] = 12345, group_id: typing.Union[int, str] = 99999, **overrides, ) -> pipeline_query.Query: """Create a group text query.""" chain = text_chain(text) event = group_message_event(chain, sender_id, group_id=group_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.GROUP, launcher_id=group_id, sender_id=sender_id, adapter=adapter, **overrides, ) def command_query( command: str = "help", prefix: str = "/", sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: """Create a command-like query.""" chain = command_chain(command, prefix) event = friend_message_event(chain, sender_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=sender_id, sender_id=sender_id, adapter=adapter, **overrides, ) def mention_query( text: str = "hello", target: typing.Union[int, str] = 12345, sender_id: typing.Union[int, str] = 12345, group_id: typing.Union[int, str] = 99999, **overrides, ) -> pipeline_query.Query: """Create a mention-bot query (group chat with @mention).""" chain = mention_chain(text, target) event = group_message_event(chain, sender_id, group_id=group_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.GROUP, launcher_id=group_id, sender_id=sender_id, adapter=adapter, **overrides, ) def empty_query(**overrides) -> pipeline_query.Query: """Create an empty message query.""" chain = platform_message.MessageChain([]) event = friend_message_event(chain) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=12345, sender_id=12345, adapter=adapter, **overrides, ) def image_query( text: str = "", url: str = "https://example.com/image.png", sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: """Create an image query.""" chain = image_chain(text, url) event = friend_message_event(chain, sender_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=sender_id, sender_id=sender_id, adapter=adapter, **overrides, ) def file_query( url: str = "https://example.com/document.pdf", name: str = "document.pdf", text: str = "", sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: """Create a file attachment query.""" components = [] if text: components.append(platform_message.Plain(text=text)) components.append(platform_message.File(url=url, name=name)) chain = platform_message.MessageChain(components) event = friend_message_event(chain, sender_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=sender_id, sender_id=sender_id, adapter=adapter, **overrides, ) def unsupported_query( unsupported_type: str = "CustomComponent", text: str = "", sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: """Create a query with unsupported/unknown message segment.""" components = [] if text: components.append(platform_message.Plain(text=text)) # Use Unknown component for unsupported types components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}")) chain = platform_message.MessageChain(components) event = friend_message_event(chain, sender_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=sender_id, sender_id=sender_id, adapter=adapter, **overrides, ) def query_with_session( text: str = "hello", sender_id: typing.Union[int, str] = 12345, session: provider_session.Session = None, **overrides, ) -> pipeline_query.Query: """Create a query with a session object. If session is None, creates a default session with empty conversation. """ if session is None: # Create a default session session = provider_session.Session( launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=sender_id, sender_id=sender_id, use_prompt_name="default", using_conversation=None, conversations=[], ) return text_query(text, sender_id, session=session, **overrides) def query_with_config( text: str = "hello", sender_id: typing.Union[int, str] = 12345, pipeline_config: dict = None, **overrides, ) -> pipeline_query.Query: """Create a query with custom pipeline configuration. If pipeline_config is None, uses default config. Useful for testing specific stage behaviors. """ if pipeline_config is None: pipeline_config = { "ai": { "runner": {"runner": "local-agent"}, "local-agent": { "model": {"primary": "test-model-uuid", "fallbacks": []}, "prompt": "test-prompt", }, }, "output": {"misc": {"at-sender": False, "quote-origin": False}}, "trigger": {"misc": {"combine-quote-message": False}}, } return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides) def voice_query( url: str = "https://example.com/audio.mp3", sender_id: typing.Union[int, str] = 12345, **overrides, ) -> pipeline_query.Query: """Create a voice/audio query.""" components = [ platform_message.Voice(url=url), ] chain = platform_message.MessageChain(components) event = friend_message_event(chain, sender_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=sender_id, sender_id=sender_id, adapter=adapter, **overrides, ) def at_all_query( text: str = "hello", sender_id: typing.Union[int, str] = 12345, group_id: typing.Union[int, str] = 99999, **overrides, ) -> pipeline_query.Query: """Create a group query with @All mention.""" components = [ platform_message.AtAll(), platform_message.Plain(text=f" {text}"), ] chain = platform_message.MessageChain(components) event = group_message_event(chain, sender_id, group_id=group_id) adapter = mock_adapter() return _base_query( message_chain=chain, message_event=event, launcher_type=provider_session.LauncherTypes.GROUP, launcher_id=group_id, sender_id=sender_id, adapter=adapter, **overrides, )