Compare commits

..

1 Commits

Author SHA1 Message Date
huanghuoguoguo
8275cfd140 fix(api): avoid mutating bot update payload 2026-05-16 10:54:04 +08:00
4 changed files with 70 additions and 49 deletions

View File

@@ -52,9 +52,6 @@ class ApiKeyService:
async def verify_api_key(self, key: str) -> bool:
"""Verify if an API key is valid"""
if not isinstance(key, str) or not key.startswith('lbk_'):
return False
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(apikey.ApiKey).where(apikey.ApiKey.key == key)
)

View File

@@ -120,24 +120,26 @@ class BotService:
async def update_bot(self, bot_uuid: str, bot_data: dict) -> None:
"""Update bot"""
if 'uuid' in bot_data:
del bot_data['uuid']
update_data = bot_data.copy()
if 'uuid' in update_data:
del update_data['uuid']
# set use_pipeline_name
if 'use_pipeline_uuid' in bot_data:
if 'use_pipeline_uuid' in update_data:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']
persistence_pipeline.LegacyPipeline.uuid == update_data['use_pipeline_uuid']
)
)
pipeline = result.first()
if pipeline is not None:
bot_data['use_pipeline_name'] = pipeline.name
update_data['use_pipeline_name'] = pipeline.name
else:
raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.update(persistence_bot.Bot).values(update_data).where(persistence_bot.Bot.uuid == bot_uuid)
)
await self.ap.platform_mgr.remove_bot(bot_uuid)

View File

@@ -0,0 +1,62 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from sqlalchemy.sql.dml import Update
from langbot.pkg.api.http.service.bot import BotService
class _FakeResult:
def __init__(self, value):
self.value = value
def first(self):
return self.value
class _PersistenceManager:
def __init__(self):
self.update_values = None
async def execute_async(self, statement):
if isinstance(statement, Update):
self.update_values = {
key: value for key, value in statement.compile().params.items() if not key.startswith('uuid_')
}
return None
return _FakeResult(SimpleNamespace(name='Updated Pipeline'))
async def test_update_bot_copies_input_before_filtering_and_setting_pipeline_name():
persistence_mgr = _PersistenceManager()
runtime_bot = SimpleNamespace(enable=False)
platform_mgr = SimpleNamespace(
remove_bot=AsyncMock(),
load_bot=AsyncMock(return_value=runtime_bot),
)
ap = SimpleNamespace(
persistence_mgr=persistence_mgr,
platform_mgr=platform_mgr,
sess_mgr=SimpleNamespace(session_list=[]),
)
service = BotService(ap)
service.get_bot = AsyncMock(return_value={'uuid': 'bot-1'})
payload = {
'uuid': 'caller-owned-uuid',
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
}
await service.update_bot('bot-1', payload)
assert payload == {
'uuid': 'caller-owned-uuid',
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
}
assert persistence_mgr.update_values == {
'name': 'Test Bot',
'use_pipeline_uuid': 'pipeline-1',
'use_pipeline_name': 'Updated Pipeline',
}

View File

@@ -1,40 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock
import pytest
from langbot.pkg.api.http.service.apikey import ApiKeyService
@pytest.mark.asyncio
@pytest.mark.parametrize('api_key', [None, 123, b'lbk_bytes', '', 'plain_key', ' LBK_bad', 'sk-lbk_fake'])
async def test_verify_api_key_rejects_non_lbk_keys_without_db_query(api_key):
persistence_mgr = SimpleNamespace(execute_async=AsyncMock())
service = ApiKeyService(SimpleNamespace(persistence_mgr=persistence_mgr))
result = await service.verify_api_key(api_key)
assert result is False
persistence_mgr.execute_async.assert_not_awaited()
@pytest.mark.asyncio
@pytest.mark.parametrize(
('db_row', 'expected'),
[
(object(), True),
(None, False),
],
)
async def test_verify_api_key_keeps_db_validation_for_lbk_keys(db_row, expected):
query_result = Mock()
query_result.first.return_value = db_row
persistence_mgr = SimpleNamespace(execute_async=AsyncMock(return_value=query_result))
service = ApiKeyService(SimpleNamespace(persistence_mgr=persistence_mgr))
result = await service.verify_api_key('lbk_valid_format')
assert result is expected
persistence_mgr.execute_async.assert_awaited_once()