diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py index bbd430db..aec13741 100644 --- a/tests/factories/__init__.py +++ b/tests/factories/__init__.py @@ -28,6 +28,12 @@ from tests.factories.message import ( mention_query, empty_query, image_query, + file_query, + unsupported_query, + voice_query, + at_all_query, + query_with_session, + query_with_config, ) from tests.factories.provider import ( FakeProvider, @@ -64,6 +70,12 @@ __all__ = [ "mention_query", "empty_query", "image_query", + "file_query", + "unsupported_query", + "voice_query", + "at_all_query", + "query_with_session", + "query_with_config", # Provider "FakeProvider", "fake_provider", diff --git a/tests/factories/message.py b/tests/factories/message.py index ecebe2f9..8871c664 100644 --- a/tests/factories/message.py +++ b/tests/factories/message.py @@ -318,4 +318,155 @@ def image_query( sender_id=sender_id, adapter=adapter, **overrides, + ) + + +def file_query( + url: str = "https://example.com/document.pdf", + name: str = "document.pdf", + text: str = "", + sender_id: typing.Union[int, str] = 12345, + **overrides, +) -> pipeline_query.Query: + """Create a file attachment query.""" + components = [] + if text: + components.append(platform_message.Plain(text=text)) + components.append(platform_message.File(url=url, name=name)) + chain = platform_message.MessageChain(components) + event = friend_message_event(chain, sender_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=sender_id, + sender_id=sender_id, + adapter=adapter, + **overrides, + ) + + +def unsupported_query( + unsupported_type: str = "CustomComponent", + text: str = "", + sender_id: typing.Union[int, str] = 12345, + **overrides, +) -> pipeline_query.Query: + """Create a query with unsupported/unknown message segment.""" + components = [] + if text: + components.append(platform_message.Plain(text=text)) + # Use Unknown component for unsupported types + components.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}")) + chain = platform_message.MessageChain(components) + event = friend_message_event(chain, sender_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=sender_id, + sender_id=sender_id, + adapter=adapter, + **overrides, + ) + + +def query_with_session( + text: str = "hello", + sender_id: typing.Union[int, str] = 12345, + session: provider_session.Session = None, + **overrides, +) -> pipeline_query.Query: + """Create a query with a session object. + + If session is None, creates a default session with empty conversation. + """ + if session is None: + # Create a default session + session = provider_session.Session( + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=sender_id, + sender_id=sender_id, + use_prompt_name="default", + using_conversation=None, + conversations=[], + ) + + return text_query(text, sender_id, session=session, **overrides) + + +def query_with_config( + text: str = "hello", + sender_id: typing.Union[int, str] = 12345, + pipeline_config: dict = None, + **overrides, +) -> pipeline_query.Query: + """Create a query with custom pipeline configuration. + + If pipeline_config is None, uses default config. + Useful for testing specific stage behaviors. + """ + if pipeline_config is None: + pipeline_config = { + "ai": { + "runner": {"runner": "local-agent"}, + "local-agent": { + "model": {"primary": "test-model-uuid", "fallbacks": []}, + "prompt": "test-prompt", + }, + }, + "output": {"misc": {"at-sender": False, "quote-origin": False}}, + "trigger": {"misc": {"combine-quote-message": False}}, + } + + return text_query(text, sender_id, pipeline_config=pipeline_config, **overrides) + + +def voice_query( + url: str = "https://example.com/audio.mp3", + sender_id: typing.Union[int, str] = 12345, + **overrides, +) -> pipeline_query.Query: + """Create a voice/audio query.""" + components = [ + platform_message.Voice(url=url), + ] + chain = platform_message.MessageChain(components) + event = friend_message_event(chain, sender_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=sender_id, + sender_id=sender_id, + adapter=adapter, + **overrides, + ) + + +def at_all_query( + text: str = "hello", + sender_id: typing.Union[int, str] = 12345, + group_id: typing.Union[int, str] = 99999, + **overrides, +) -> pipeline_query.Query: + """Create a group query with @All mention.""" + components = [ + platform_message.AtAll(), + platform_message.Plain(text=f" {text}"), + ] + chain = platform_message.MessageChain(components) + event = group_message_event(chain, sender_id, group_id=group_id) + adapter = mock_adapter() + return _base_query( + message_chain=chain, + message_event=event, + launcher_type=provider_session.LauncherTypes.GROUP, + launcher_id=group_id, + sender_id=sender_id, + adapter=adapter, + **overrides, ) \ No newline at end of file