fix: tests

This commit is contained in:
Junyan Qin
2025-11-16 18:39:45 +08:00
parent e7885539a7
commit d26e81620d
6 changed files with 47 additions and 146 deletions

View File

@@ -26,7 +26,7 @@ markers =
# Coverage options (when using pytest-cov) # Coverage options (when using pytest-cov)
[coverage:run] [coverage:run]
source = pkg source = langbot.pkg
omit = omit =
*/tests/* */tests/*
*/test_*.py */test_*.py

View File

@@ -5,7 +5,6 @@ Tests the actual BanSessionCheckStage implementation from pkg.pipeline.bansess
""" """
import pytest import pytest
from unittest.mock import Mock
from importlib import import_module from importlib import import_module
import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.provider.session as provider_session
@@ -13,9 +12,8 @@ import langbot_plugin.api.entities.builtin.provider.session as provider_session
def get_modules(): def get_modules():
"""Lazy import to ensure proper initialization order""" """Lazy import to ensure proper initialization order"""
# Import pipelinemgr first to trigger proper stage registration # Import pipelinemgr first to trigger proper stage registration
pipelinemgr = import_module('pkg.pipeline.pipelinemgr') bansess = import_module('langbot.pkg.pipeline.bansess.bansess')
bansess = import_module('pkg.pipeline.bansess.bansess') entities = import_module('langbot.pkg.pipeline.entities')
entities = import_module('pkg.pipeline.entities')
return bansess, entities return bansess, entities
@@ -26,14 +24,7 @@ async def test_whitelist_allow(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.PERSON sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345' sample_query.launcher_id = '12345'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'access-control': {'mode': 'whitelist', 'whitelist': ['person_12345']}}}
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['person_12345']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app) stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -51,14 +42,7 @@ async def test_whitelist_deny(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.PERSON sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '99999' sample_query.launcher_id = '99999'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'access-control': {'mode': 'whitelist', 'whitelist': ['person_12345']}}}
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['person_12345']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app) stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -75,14 +59,7 @@ async def test_blacklist_allow(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.PERSON sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345' sample_query.launcher_id = '12345'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'access-control': {'mode': 'blacklist', 'blacklist': ['person_99999']}}}
'trigger': {
'access-control': {
'mode': 'blacklist',
'blacklist': ['person_99999']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app) stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -99,14 +76,7 @@ async def test_blacklist_deny(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.PERSON sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345' sample_query.launcher_id = '12345'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'access-control': {'mode': 'blacklist', 'blacklist': ['person_12345']}}}
'trigger': {
'access-control': {
'mode': 'blacklist',
'blacklist': ['person_12345']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app) stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -123,14 +93,7 @@ async def test_wildcard_group(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.GROUP sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.launcher_id = '12345' sample_query.launcher_id = '12345'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'access-control': {'mode': 'whitelist', 'whitelist': ['group_*']}}}
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['group_*']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app) stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -147,14 +110,7 @@ async def test_wildcard_person(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.PERSON sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345' sample_query.launcher_id = '12345'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'access-control': {'mode': 'whitelist', 'whitelist': ['person_*']}}}
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['person_*']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app) stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -172,14 +128,7 @@ async def test_user_id_wildcard(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.PERSON sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.launcher_id = '12345' sample_query.launcher_id = '12345'
sample_query.sender_id = '67890' sample_query.sender_id = '67890'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'access-control': {'mode': 'whitelist', 'whitelist': ['*_67890']}}}
'trigger': {
'access-control': {
'mode': 'whitelist',
'whitelist': ['*_67890']
}
}
}
stage = bansess.BanSessionCheckStage(mock_app) stage = bansess.BanSessionCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)

View File

@@ -8,19 +8,19 @@ from importlib import import_module
def get_pipelinemgr_module(): def get_pipelinemgr_module():
return import_module('pkg.pipeline.pipelinemgr') return import_module('langbot.pkg.pipeline.pipelinemgr')
def get_stage_module(): def get_stage_module():
return import_module('pkg.pipeline.stage') return import_module('langbot.pkg.pipeline.stage')
def get_entities_module(): def get_entities_module():
return import_module('pkg.pipeline.entities') return import_module('langbot.pkg.pipeline.entities')
def get_persistence_pipeline_module(): def get_persistence_pipeline_module():
return import_module('pkg.entity.persistence.pipeline') return import_module('langbot.pkg.entity.persistence.pipeline')
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -13,10 +13,9 @@ import langbot_plugin.api.entities.builtin.provider.session as provider_session
def get_modules(): def get_modules():
"""Lazy import to ensure proper initialization order""" """Lazy import to ensure proper initialization order"""
# Import pipelinemgr first to trigger proper stage registration # Import pipelinemgr first to trigger proper stage registration
pipelinemgr = import_module('pkg.pipeline.pipelinemgr') ratelimit = import_module('langbot.pkg.pipeline.ratelimit.ratelimit')
ratelimit = import_module('pkg.pipeline.ratelimit.ratelimit') entities = import_module('langbot.pkg.pipeline.entities')
entities = import_module('pkg.pipeline.entities') algo_module = import_module('langbot.pkg.pipeline.ratelimit.algo')
algo_module = import_module('pkg.pipeline.ratelimit.algo')
return ratelimit, entities, algo_module return ratelimit, entities, algo_module
@@ -44,11 +43,7 @@ async def test_require_access_allowed(mock_app, sample_query):
assert result.result_type == entities.ResultType.CONTINUE assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query == sample_query assert result.new_query == sample_query
mock_algo.require_access.assert_called_once_with( mock_algo.require_access.assert_called_once_with(sample_query, 'person', '12345')
sample_query,
'person',
'12345'
)
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -102,8 +97,4 @@ async def test_release_access(mock_app, sample_query):
assert result.result_type == entities.ResultType.CONTINUE assert result.result_type == entities.ResultType.CONTINUE
assert result.new_query == sample_query assert result.new_query == sample_query
mock_algo.release_access.assert_called_once_with( mock_algo.release_access.assert_called_once_with(sample_query, 'person', '12345')
sample_query,
'person',
'12345'
)

View File

@@ -14,11 +14,11 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
def get_modules(): def get_modules():
"""Lazy import to ensure proper initialization order""" """Lazy import to ensure proper initialization order"""
# Import pipelinemgr first to trigger proper stage registration # Import pipelinemgr first to trigger proper stage registration
pipelinemgr = import_module('pkg.pipeline.pipelinemgr') # pipelinemgr = import_module('langbot.pkg.pipeline.pipelinemgr')
resprule = import_module('pkg.pipeline.resprule.resprule') resprule = import_module('langbot.pkg.pipeline.resprule.resprule')
entities = import_module('pkg.pipeline.entities') entities = import_module('langbot.pkg.pipeline.entities')
rule = import_module('pkg.pipeline.resprule.rule') rule = import_module('langbot.pkg.pipeline.resprule.rule')
rule_entities = import_module('pkg.pipeline.resprule.entities') rule_entities = import_module('langbot.pkg.pipeline.resprule.entities')
return resprule, entities, rule, rule_entities return resprule, entities, rule, rule_entities
@@ -28,11 +28,7 @@ async def test_person_message_skip(mock_app, sample_query):
resprule, entities, rule, rule_entities = get_modules() resprule, entities, rule, rule_entities = get_modules()
sample_query.launcher_type = provider_session.LauncherTypes.PERSON sample_query.launcher_type = provider_session.LauncherTypes.PERSON
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'group-respond-rules': {}}}
'trigger': {
'group-respond-rules': {}
}
}
stage = resprule.GroupRespondRuleCheckStage(mock_app) stage = resprule.GroupRespondRuleCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -50,18 +46,13 @@ async def test_group_message_no_match(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.GROUP sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.launcher_id = '12345' sample_query.launcher_id = '12345'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'group-respond-rules': {}}}
'trigger': {
'group-respond-rules': {}
}
}
# Create mock rule matcher that doesn't match # Create mock rule matcher that doesn't match
mock_rule = Mock(spec=rule.GroupRespondRule) mock_rule = Mock(spec=rule.GroupRespondRule)
mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult( mock_rule.match = AsyncMock(
matching=False, return_value=rule_entities.RuleJudgeResult(matching=False, replacement=sample_query.message_chain)
replacement=sample_query.message_chain )
))
stage = resprule.GroupRespondRuleCheckStage(mock_app) stage = resprule.GroupRespondRuleCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -81,23 +72,14 @@ async def test_group_message_match(mock_app, sample_query):
sample_query.launcher_type = provider_session.LauncherTypes.GROUP sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.launcher_id = '12345' sample_query.launcher_id = '12345'
sample_query.pipeline_config = { sample_query.pipeline_config = {'trigger': {'group-respond-rules': {}}}
'trigger': {
'group-respond-rules': {}
}
}
# Create new message chain after rule processing # Create new message chain after rule processing
new_chain = platform_message.MessageChain([ new_chain = platform_message.MessageChain([platform_message.Plain(text='Processed message')])
platform_message.Plain(text='Processed message')
])
# Create mock rule matcher that matches # Create mock rule matcher that matches
mock_rule = Mock(spec=rule.GroupRespondRule) mock_rule = Mock(spec=rule.GroupRespondRule)
mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult( mock_rule.match = AsyncMock(return_value=rule_entities.RuleJudgeResult(matching=True, replacement=new_chain))
matching=True,
replacement=new_chain
))
stage = resprule.GroupRespondRuleCheckStage(mock_app) stage = resprule.GroupRespondRuleCheckStage(mock_app)
await stage.initialize(sample_query.pipeline_config) await stage.initialize(sample_query.pipeline_config)
@@ -115,27 +97,21 @@ async def test_group_message_match(mock_app, sample_query):
async def test_atbot_rule_match(mock_app, sample_query): async def test_atbot_rule_match(mock_app, sample_query):
"""Test AtBotRule removes At component""" """Test AtBotRule removes At component"""
resprule, entities, rule, rule_entities = get_modules() resprule, entities, rule, rule_entities = get_modules()
atbot_module = import_module('pkg.pipeline.resprule.rules.atbot') atbot_module = import_module('langbot.pkg.pipeline.resprule.rules.atbot')
sample_query.launcher_type = provider_session.LauncherTypes.GROUP sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.adapter.bot_account_id = '999' sample_query.adapter.bot_account_id = '999'
# Create message chain with At component # Create message chain with At component
message_chain = platform_message.MessageChain([ message_chain = platform_message.MessageChain(
platform_message.At(target='999'), [platform_message.At(target='999'), platform_message.Plain(text='Hello bot')]
platform_message.Plain(text='Hello bot') )
])
sample_query.message_chain = message_chain sample_query.message_chain = message_chain
atbot_rule = atbot_module.AtBotRule(mock_app) atbot_rule = atbot_module.AtBotRule(mock_app)
await atbot_rule.initialize() await atbot_rule.initialize()
result = await atbot_rule.match( result = await atbot_rule.match(str(message_chain), message_chain, {}, sample_query)
str(message_chain),
message_chain,
{},
sample_query
)
assert result.matching is True assert result.matching is True
# At component should be removed # At component should be removed
@@ -147,25 +123,18 @@ async def test_atbot_rule_match(mock_app, sample_query):
async def test_atbot_rule_no_match(mock_app, sample_query): async def test_atbot_rule_no_match(mock_app, sample_query):
"""Test AtBotRule when no At component present""" """Test AtBotRule when no At component present"""
resprule, entities, rule, rule_entities = get_modules() resprule, entities, rule, rule_entities = get_modules()
atbot_module = import_module('pkg.pipeline.resprule.rules.atbot') atbot_module = import_module('langbot.pkg.pipeline.resprule.rules.atbot')
sample_query.launcher_type = provider_session.LauncherTypes.GROUP sample_query.launcher_type = provider_session.LauncherTypes.GROUP
sample_query.adapter.bot_account_id = '999' sample_query.adapter.bot_account_id = '999'
# Create message chain without At component # Create message chain without At component
message_chain = platform_message.MessageChain([ message_chain = platform_message.MessageChain([platform_message.Plain(text='Hello')])
platform_message.Plain(text='Hello')
])
sample_query.message_chain = message_chain sample_query.message_chain = message_chain
atbot_rule = atbot_module.AtBotRule(mock_app) atbot_rule = atbot_module.AtBotRule(mock_app)
await atbot_rule.initialize() await atbot_rule.initialize()
result = await atbot_rule.match( result = await atbot_rule.match(str(message_chain), message_chain, {}, sample_query)
str(message_chain),
message_chain,
{},
sample_query
)
assert result.matching is False assert result.matching is False

View File

@@ -4,9 +4,9 @@ Tests for storage manager and provider selection
import pytest import pytest
from unittest.mock import Mock, AsyncMock, patch from unittest.mock import Mock, AsyncMock, patch
from pkg.storage.mgr import StorageMgr from langbot.pkg.storage.mgr import StorageMgr
from pkg.storage.providers.localstorage import LocalStorageProvider from langbot.pkg.storage.providers.localstorage import LocalStorageProvider
from pkg.storage.providers.s3storage import S3StorageProvider from langbot.pkg.storage.providers.s3storage import S3StorageProvider
class TestStorageProviderSelection: class TestStorageProviderSelection:
@@ -34,11 +34,7 @@ class TestStorageProviderSelection:
# Mock application # Mock application
mock_app = Mock() mock_app = Mock()
mock_app.instance_config = Mock() mock_app.instance_config = Mock()
mock_app.instance_config.data = { mock_app.instance_config.data = {'storage': {'use': 'local'}}
'storage': {
'use': 'local'
}
}
mock_app.logger = Mock() mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app) storage_mgr = StorageMgr(mock_app)
@@ -62,8 +58,8 @@ class TestStorageProviderSelection:
'access_key_id': 'test_key', 'access_key_id': 'test_key',
'secret_access_key': 'test_secret', 'secret_access_key': 'test_secret',
'region': 'us-east-1', 'region': 'us-east-1',
'bucket': 'test-bucket' 'bucket': 'test-bucket',
} },
} }
} }
mock_app.logger = Mock() mock_app.logger = Mock()
@@ -81,11 +77,7 @@ class TestStorageProviderSelection:
# Mock application # Mock application
mock_app = Mock() mock_app = Mock()
mock_app.instance_config = Mock() mock_app.instance_config = Mock()
mock_app.instance_config.data = { mock_app.instance_config.data = {'storage': {'use': 'invalid_type'}}
'storage': {
'use': 'invalid_type'
}
}
mock_app.logger = Mock() mock_app.logger = Mock()
storage_mgr = StorageMgr(mock_app) storage_mgr = StorageMgr(mock_app)