Compare commits

..

1 Commits

Author SHA1 Message Date
huanghuoguoguo
8275cfd140 fix(api): avoid mutating bot update payload 2026-05-16 10:54:04 +08:00
4 changed files with 70 additions and 49 deletions

View File

@@ -120,24 +120,26 @@ class BotService:
async def update_bot(self, bot_uuid: str, bot_data: dict) -> None:
"""Update bot"""
if 'uuid' in bot_data:
del bot_data['uuid']
update_data = bot_data.copy()
if 'uuid' in update_data:
del update_data['uuid']
# set use_pipeline_name
if 'use_pipeline_uuid' in bot_data:
if 'use_pipeline_uuid' in update_data:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']
persistence_pipeline.LegacyPipeline.uuid == update_data['use_pipeline_uuid']
)
)
pipeline = result.first()
if pipeline is not None:
bot_data['use_pipeline_name'] = pipeline.name
update_data['use_pipeline_name'] = pipeline.name
else:
raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.update(persistence_bot.Bot).values(update_data).where(persistence_bot.Bot.uuid == bot_uuid)
)
await self.ap.platform_mgr.remove_bot(bot_uuid)

View File

@@ -275,7 +275,6 @@ class MessageAggregator:
message_chain=merged_chain,
adapter=base_msg.adapter,
pipeline_uuid=base_msg.pipeline_uuid,
routed_by_rule=any(msg.routed_by_rule for msg in messages),
)
async def flush_all(self) -> None:

View File

@@ -0,0 +1,62 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from sqlalchemy.sql.dml import Update
from langbot.pkg.api.http.service.bot import BotService
class _FakeResult:
def __init__(self, value):
self.value = value
def first(self):
return self.value
class _PersistenceManager:
def __init__(self):
self.update_values = None
async def execute_async(self, statement):
if isinstance(statement, Update):
self.update_values = {
key: value for key, value in statement.compile().params.items() if not key.startswith('uuid_')
}
return None
return _FakeResult(SimpleNamespace(name='Updated Pipeline'))
async def test_update_bot_copies_input_before_filtering_and_setting_pipeline_name():
persistence_mgr = _PersistenceManager()
runtime_bot = SimpleNamespace(enable=False)
platform_mgr = SimpleNamespace(
remove_bot=AsyncMock(),
load_bot=AsyncMock(return_value=runtime_bot),
)
ap = SimpleNamespace(
persistence_mgr=persistence_mgr,
platform_mgr=platform_mgr,
sess_mgr=SimpleNamespace(session_list=[]),
)
service = BotService(ap)
service.get_bot = AsyncMock(return_value={'uuid': 'bot-1'})
payload = {
'uuid': 'caller-owned-uuid',
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
}
await service.update_bot('bot-1', payload)
assert payload == {
'uuid': 'caller-owned-uuid',
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
}
assert persistence_mgr.update_values == {
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
'use_pipeline_name': 'Updated Pipeline',
}

View File

@@ -1,42 +0,0 @@
"""
MessageAggregator unit tests.
"""
from importlib import import_module
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(sample_message_event, mock_adapter):
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
aggregator = import_module('langbot.pkg.pipeline.aggregator')
message_aggregator = aggregator.MessageAggregator(ap=None)
first_message = aggregator.PendingMessage(
bot_uuid='test-bot-uuid',
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
message_event=sample_message_event,
message_chain=platform_message.MessageChain([platform_message.Plain(text='first')]),
adapter=mock_adapter,
pipeline_uuid='test-pipeline-uuid',
routed_by_rule=False,
)
second_message = aggregator.PendingMessage(
bot_uuid='test-bot-uuid',
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=12345,
sender_id=12345,
message_event=sample_message_event,
message_chain=platform_message.MessageChain([platform_message.Plain(text='second')]),
adapter=mock_adapter,
pipeline_uuid='test-pipeline-uuid',
routed_by_rule=True,
)
merged_message = message_aggregator._merge_messages([first_message, second_message])
assert merged_message.routed_by_rule is True
assert str(merged_message.message_chain) == 'first\nsecond'