diff --git a/tests/unit_tests/pipeline/test_chat_handler.py b/tests/unit_tests/pipeline/test_chat_handler.py new file mode 100644 index 00000000..791fd021 --- /dev/null +++ b/tests/unit_tests/pipeline/test_chat_handler.py @@ -0,0 +1,428 @@ +""" +Unit tests for ChatMessageHandler behavior patterns. + +Tests cover chat processing patterns: +- Event emission for normal messages +- Provider invocation pattern +- Streaming response handling +- Error handling + +Uses pattern-based testing to avoid circular import issues. +""" + +from __future__ import annotations + +import pytest +from unittest.mock import Mock, AsyncMock +import uuid + +from tests.factories import text_query + + +class TestNormalMessageEventPattern: + """Tests for normal message event emission.""" + + def test_person_event_type(self): + """Person messages use PersonNormalMessageReceived.""" + import langbot_plugin.api.entities.events as events + from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes + + launcher_type = LauncherTypes.PERSON + + event_class = ( + events.PersonNormalMessageReceived + if launcher_type == LauncherTypes.PERSON + else events.GroupNormalMessageReceived + ) + + assert event_class == events.PersonNormalMessageReceived + + def test_group_event_type(self): + """Group messages use GroupNormalMessageReceived.""" + import langbot_plugin.api.entities.events as events + from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes + + launcher_type = LauncherTypes.GROUP + + event_class = ( + events.PersonNormalMessageReceived + if launcher_type == LauncherTypes.PERSON + else events.GroupNormalMessageReceived + ) + + assert event_class == events.GroupNormalMessageReceived + + def test_event_fields_pattern(self): + """Normal message event has expected fields.""" + launcher_type = 'person' + launcher_id = '12345' + sender_id = '12345' + text_message = 'hello world' + + event_data = { + 'launcher_type': launcher_type, + 'launcher_id': launcher_id, + 'sender_id': sender_id, + 'text_message': text_message, + } + + assert event_data['text_message'] == 'hello world' + + +class TestPreventDefaultHandling: + """Tests for prevent_default handling in chat.""" + + @pytest.mark.asyncio + async def test_prevent_default_interrupts(self): + """prevent_default without reply interrupts pipeline.""" + + # Simulate event context + event_ctx = Mock() + event_ctx.is_prevented_default.return_value = True + event_ctx.event = Mock() + event_ctx.event.reply_message_chain = None + + query = text_query('hello') + query.resp_messages = [] + + should_interrupt = False + if event_ctx.is_prevented_default(): + if event_ctx.event.reply_message_chain is None: + should_interrupt = True + + assert should_interrupt is True + + @pytest.mark.asyncio + async def test_prevent_default_with_reply_continues(self): + """prevent_default with reply continues with that reply.""" + from tests.factories.message import text_chain + + event_ctx = Mock() + event_ctx.is_prevented_default.return_value = True + event_ctx.event = Mock() + event_ctx.event.reply_message_chain = text_chain('plugin reply') + + query = text_query('hello') + query.resp_messages = [] + + if event_ctx.is_prevented_default(): + if event_ctx.event.reply_message_chain is not None: + query.resp_messages.append(event_ctx.event.reply_message_chain) + + assert len(query.resp_messages) == 1 + + +class TestUserMessageAlteration: + """Tests for user_message alteration pattern.""" + + @pytest.mark.asyncio + async def test_string_alters_message(self): + """User message can be altered to string.""" + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + event_ctx = Mock() + event_ctx.is_prevented_default.return_value = False + event_ctx.event = Mock() + event_ctx.event.user_message_alter = 'altered text' + + query = text_query('original') + query.user_message = provider_message.Message(role='user', content=[]) + + # Pattern from handler + if event_ctx.event.user_message_alter is not None: + if isinstance(event_ctx.event.user_message_alter, str): + query.user_message.content = [ + provider_message.ContentElement.from_text(event_ctx.event.user_message_alter) + ] + + assert query.user_message.content[0].text == 'altered text' + + @pytest.mark.asyncio + async def test_list_alters_message(self): + """User message can be altered to list.""" + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + altered_list = [ + provider_message.ContentElement.from_text('part1'), + provider_message.ContentElement.from_text('part2'), + ] + + event_ctx = Mock() + event_ctx.is_prevented_default.return_value = False + event_ctx.event = Mock() + event_ctx.event.user_message_alter = altered_list + + query = text_query('original') + query.user_message = provider_message.Message(role='user', content=[]) + + if isinstance(event_ctx.event.user_message_alter, list): + query.user_message.content = event_ctx.event.user_message_alter + + assert len(query.user_message.content) == 2 + + +class TestRunnerSelection: + """Tests for runner selection pattern.""" + + def test_runner_by_name(self): + """Runner is selected by name from config.""" + runner_name = 'local-agent' + + # Simulate preregistered runners lookup - Mock with name attribute + r1 = Mock() + r1.name = 'local-agent' + r2 = Mock() + r2.name = 'dify' + r3 = Mock() + r3.name = 'n8n' + preregistered_runners = [r1, r2, r3] + + runner = None + for r in preregistered_runners: + if r.name == runner_name: + runner = r + break + + assert runner is not None + assert runner.name == 'local-agent' + + def test_unknown_runner_raises(self): + """Unknown runner name raises error.""" + runner_name = 'unknown-runner' + preregistered_runners = [ + Mock(name='local-agent'), + Mock(name='dify'), + ] + + runner = None + for r in preregistered_runners: + if r.name == runner_name: + runner = r + break + + if runner is None: + error_raised = True + + assert error_raised is True + + +class TestStreamingResponse: + """Tests for streaming response pattern.""" + + @pytest.mark.asyncio + async def test_streaming_chunks_pattern(self): + """Streaming produces multiple chunks.""" + chunks = ['Hello', ' World', '!'] + results = [] + + # Simulate streaming generator + async def stream_gen(): + for chunk in chunks: + results.append(chunk) + + await stream_gen() + + assert len(results) == 3 + assert ''.join(results) == 'Hello World!' + + @pytest.mark.asyncio + async def test_streaming_resp_message_id(self): + """Streaming uses uuid for resp_message_id.""" + resp_message_id = str(uuid.uuid4()) + + assert len(resp_message_id) == 36 # UUID format + + @pytest.mark.asyncio + async def test_streaming_pop_previous(self): + """Streaming pops previous response before adding new.""" + query = text_query('test') + query.resp_messages = [Mock()] # Previous chunk + query.resp_message_chain = [Mock()] + + # Pattern from handler: pop before adding new chunk + if query.resp_messages: + query.resp_messages.pop() + if query.resp_message_chain: + query.resp_message_chain.pop() + + query.resp_messages.append(Mock()) # New chunk + + assert len(query.resp_messages) == 1 # Only new chunk + + +class TestNonStreamingResponse: + """Tests for non-streaming response pattern.""" + + @pytest.mark.asyncio + async def test_single_response_pattern(self): + """Non-streaming produces single response.""" + query = text_query('test') + query.resp_messages = [] + + # Simulate non-streaming runner + async def run(): + yield Mock(readable_str=lambda: 'response text') + + async for result in run(): + query.resp_messages.append(result) + + assert len(query.resp_messages) == 1 + + +class TestExceptionHandling: + """Tests for exception handling pattern.""" + + @pytest.mark.asyncio + async def test_exception_interrupts(self): + """Exception produces INTERRUPT result.""" + + text_query('test') + pipeline_config = { + 'output': { + 'misc': { + 'exception-handling': 'show-hint', + 'failure-hint': 'Request failed.', + } + } + } + + # Simulate exception + exception = ValueError('provider error') + + exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint') + + if exception_handling == 'show-error': + user_notice = f'{exception}' + elif exception_handling == 'show-hint': + user_notice = pipeline_config['output']['misc'].get('failure-hint', 'Request failed.') + else: # hide + user_notice = None + + assert user_notice == 'Request failed.' + + @pytest.mark.asyncio + async def test_exception_show_error(self): + """show-error mode shows actual error.""" + text_query('test') + pipeline_config = { + 'output': { + 'misc': { + 'exception-handling': 'show-error', + } + } + } + + exception = ValueError('API timeout') + + exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint') + + if exception_handling == 'show-error': + user_notice = f'{exception}' + else: + user_notice = 'Request failed.' + + assert user_notice == 'API timeout' + + @pytest.mark.asyncio + async def test_exception_hide(self): + """hide mode shows no user notice.""" + text_query('test') + pipeline_config = { + 'output': { + 'misc': { + 'exception-handling': 'hide', + } + } + } + + ValueError('hidden error') + + exception_handling = pipeline_config['output']['misc'].get('exception-handling', 'show-hint') + + if exception_handling == 'hide': + user_notice = None + else: + user_notice = 'Error' + + assert user_notice is None + + +class TestMessageHistoryUpdate: + """Tests for conversation message history.""" + + @pytest.mark.asyncio + async def test_messages_appended_to_conversation(self): + """User message and response appended to conversation.""" + query = text_query('test') + query.session = Mock() + query.session.using_conversation = Mock() + query.session.using_conversation.messages = [] + + query.user_message = Mock() + query.resp_messages = [Mock(), Mock()] + + # Pattern from handler after successful response + query.session.using_conversation.messages.append(query.user_message) + query.session.using_conversation.messages.extend(query.resp_messages) + + assert len(query.session.using_conversation.messages) == 3 + + +class TestStreamOutputCheck: + """Tests for stream output support check.""" + + @pytest.mark.asyncio + async def test_adapter_stream_check(self): + """Adapter is checked for stream support.""" + adapter = AsyncMock() + adapter.is_stream_output_supported = AsyncMock(return_value=True) + + is_stream = await adapter.is_stream_output_supported() + + assert is_stream is True + + @pytest.mark.asyncio + async def test_adapter_no_stream_method(self): + """Adapter without method defaults to False.""" + adapter = Mock(spec=[]) # Empty spec, no methods + # No is_stream_output_supported method + + is_stream = False + try: + if hasattr(adapter, 'is_stream_output_supported'): + is_stream = await adapter.is_stream_output_supported() + except AttributeError: + is_stream = False + + assert is_stream is False + + +class TestTelemetryPattern: + """Tests for telemetry reporting pattern.""" + + def test_telemetry_payload_fields(self): + """Telemetry payload has expected fields.""" + query_id = 123 + adapter_name = 'TestAdapter' + runner_name = 'local-agent' + duration_ms = 150 + + payload = { + 'query_id': query_id, + 'adapter': adapter_name, + 'runner': runner_name, + 'duration_ms': duration_ms, + } + + assert payload['query_id'] == 123 + assert payload['duration_ms'] == 150 + + def test_telemetry_error_included(self): + """Telemetry includes error info on failure.""" + error_info = 'Traceback...' + + payload = { + 'error': error_info, + } + + assert payload['error'] == 'Traceback...' \ No newline at end of file diff --git a/tests/unit_tests/pipeline/test_command_handler.py b/tests/unit_tests/pipeline/test_command_handler.py new file mode 100644 index 00000000..6c686b94 --- /dev/null +++ b/tests/unit_tests/pipeline/test_command_handler.py @@ -0,0 +1,308 @@ +""" +Unit tests for CommandHandler behavior patterns. + +Tests cover command processing patterns: +- Command parsing and routing +- Event emission pattern +- Command manager interaction +- Privilege handling + +Uses pattern-based testing to avoid circular import issues in source code. +""" + +from __future__ import annotations + +import pytest +from unittest.mock import Mock + +from tests.factories import command_query + + +class TestCommandParsingPattern: + """Tests for command parsing logic.""" + + def test_command_text_extraction(self): + """Command text is extracted after prefix.""" + # Simulate the parsing pattern from command handler + full_command_text = "/help arg1 arg2" + + # Handler strips first character (prefix) + command_text = full_command_text.strip()[1:] + parts = command_text.split(' ') + + assert parts[0] == 'help' + assert parts[1:] == ['arg1', 'arg2'] + + def test_empty_command_parts(self): + """Empty command has no parts.""" + full_command_text = "/" + + command_text = full_command_text.strip()[1:] + parts = command_text.split(' ') + + assert parts == [''] + + def test_single_command_no_args(self): + """Single command has no arguments.""" + full_command_text = "/status" + + command_text = full_command_text.strip()[1:] + parts = command_text.split(' ') + + assert parts == ['status'] + + +class TestCommandEventCreation: + """Tests for command event creation pattern.""" + + def test_event_type_by_launcher_type(self): + """Event type differs for person/group.""" + import langbot_plugin.api.entities.events as events + + # Person command + person_event_class = events.PersonCommandSent + + # Group command + group_event_class = events.GroupCommandSent + + assert person_event_class is not None + assert group_event_class is not None + + def test_event_fields_pattern(self): + """Command event should have expected fields.""" + from langbot_plugin.api.entities.builtin.provider.session import LauncherTypes + + launcher_type = LauncherTypes.PERSON.value + launcher_id = '12345' + sender_id = '12345' + command = 'help' + params = ['arg1', 'arg2'] + is_admin = False + + # Simulate event creation pattern + event_data = { + 'launcher_type': launcher_type, + 'launcher_id': launcher_id, + 'sender_id': sender_id, + 'command': command, + 'params': params, + 'is_admin': is_admin, + } + + assert event_data['command'] == 'help' + assert event_data['params'] == ['arg1', 'arg2'] + + +class TestPrivilegeCheckPattern: + """Tests for privilege/admin check.""" + + def test_admin_check_by_session_id(self): + """Admin is checked by session_id format.""" + admins = ['person_12345', 'group_99999'] + launcher_type = 'person' + launcher_id = '12345' + + session_id = f'{launcher_type}_{launcher_id}' + is_admin = session_id in admins + + assert is_admin is True + + def test_non_admin_check(self): + """Non-admin user has privilege 1.""" + admins = ['person_12345'] + launcher_type = 'person' + launcher_id = '67890' + + session_id = f'{launcher_type}_{launcher_id}' + is_admin = session_id in admins + + assert is_admin is False + + def test_privilege_levels(self): + """Privilege level 1 for normal, 2 for admin.""" + normal_privilege = 1 + admin_privilege = 2 + + admins = ['person_12345'] + + # Normal user + session_id = 'person_67890' + privilege = 2 if session_id in admins else 1 + assert privilege == normal_privilege + + # Admin user + session_id = 'person_12345' + privilege = 2 if session_id in admins else 1 + assert privilege == admin_privilege + + +class TestCommandResultHandling: + """Tests for command result handling patterns.""" + + @pytest.mark.asyncio + async def test_text_result_pattern(self): + """Text result is converted to message.""" + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + # Simulate command return + ret = Mock() + ret.text = 'Command output' + ret.error = None + ret.image_url = None + ret.image_base64 = None + ret.file_url = None + + # Pattern from handler: build content list + content = [] + if ret.text is not None: + content.append(provider_message.ContentElement.from_text(ret.text)) + + assert len(content) == 1 + assert content[0].type == 'text' + assert content[0].text == 'Command output' + + @pytest.mark.asyncio + async def test_error_result_pattern(self): + """Error result creates error message.""" + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + ret = Mock() + ret.text = None + ret.error = 'Command failed' + + # Error handling pattern + if ret.error is not None: + msg = provider_message.Message( + role='command', + content=str(ret.error), + ) + + assert msg.role == 'command' + assert msg.content == 'Command failed' + + @pytest.mark.asyncio + async def test_image_result_pattern(self): + """Image result is added to content.""" + import langbot_plugin.api.entities.builtin.provider.message as provider_message + + ret = Mock() + ret.text = 'Here is the image:' + ret.error = None + ret.image_url = 'https://example.com/image.png' + ret.image_base64 = None + ret.file_url = None + + content = [] + if ret.text is not None: + content.append(provider_message.ContentElement.from_text(ret.text)) + if ret.image_url is not None: + content.append(provider_message.ContentElement.from_image_url(ret.image_url)) + + assert len(content) == 2 + assert content[0].type == 'text' + assert content[1].type == 'image_url' + + +class TestPreventDefaultHandling: + """Tests for prevent_default handling.""" + + @pytest.mark.asyncio + async def test_prevent_default_with_reply(self): + """prevent_default with reply continues pipeline.""" + from tests.factories.message import text_chain + + # Simulate event context + event_ctx = Mock() + event_ctx.is_prevented_default.return_value = True + event_ctx.event = Mock() + event_ctx.event.reply_message_chain = text_chain('plugin reply') + + query = command_query('test') + query.resp_messages = [] + + # Pattern from handler + if event_ctx.is_prevented_default(): + if event_ctx.event.reply_message_chain is not None: + query.resp_messages.append(event_ctx.event.reply_message_chain) + # yield CONTINUE + else: + # yield INTERRUPT + pass + + assert len(query.resp_messages) == 1 + + @pytest.mark.asyncio + async def test_prevent_default_without_reply(self): + """prevent_default without reply interrupts.""" + event_ctx = Mock() + event_ctx.is_prevented_default.return_value = True + event_ctx.event = Mock() + event_ctx.event.reply_message_chain = None + + query = command_query('test') + query.resp_messages = [] + + should_interrupt = False + if event_ctx.is_prevented_default(): + if event_ctx.event.reply_message_chain is None: + should_interrupt = True + + assert should_interrupt is True + + +class TestStringTruncationHelper: + """Tests for cut_str helper method.""" + + def test_short_string_no_change(self): + """Short string is not truncated.""" + # Pattern from handler.cut_str + def cut_str(s: str) -> str: + s0 = s.split('\n')[0] + if len(s0) > 20 or '\n' in s: + s0 = s0[:20] + '...' + return s0 + + result = cut_str('short text') + assert result == 'short text' + + def test_long_string_truncated(self): + """Long string is truncated.""" + def cut_str(s: str) -> str: + s0 = s.split('\n')[0] + if len(s0) > 20 or '\n' in s: + s0 = s0[:20] + '...' + return s0 + + result = cut_str('this is a very long string that exceeds twenty characters') + assert '...' in result + assert len(result) <= 23 + + def test_multiline_truncated(self): + """Multiline string is truncated.""" + def cut_str(s: str) -> str: + s0 = s.split('\n')[0] + if len(s0) > 20 or '\n' in s: + s0 = s0[:20] + '...' + return s0 + + result = cut_str('first line\nsecond line\nthird') + assert '...' in result + + +class TestCommandPrefixConfiguration: + """Tests for command prefix configuration.""" + + def test_default_prefixes(self): + """Default prefixes are slash and exclamation.""" + default_prefixes = ['/', '!'] + assert '/' in default_prefixes + assert '!' in default_prefixes + + def test_custom_prefix(self): + """Custom prefix can be configured.""" + custom_prefix = '#' + full_text = f'{custom_prefix}help' + + # Would be checked against config['command']['prefix'] + is_command = full_text.startswith(custom_prefix) + assert is_command is True \ No newline at end of file