From 6713b57d017b6edb0dad0391852017afd17bdfc8 Mon Sep 17 00:00:00 2001 From: fdc310 <2213070223@qq.com> Date: Mon, 11 May 2026 15:03:30 +0800 Subject: [PATCH] feat: enhance API key normalization and improve Space OAuth callback handling --- .../pkg/api/http/controller/groups/user.py | 1 + src/langbot/pkg/api/http/service/provider.py | 26 ++++++-- .../platform/sources/web_page_bot_adapter.py | 5 +- .../pkg/provider/modelmgr/requester.py | 1 + .../provider/modelmgr/requesters/chatcmpl.py | 1 - .../modelmgr/requesters/modelscopechatcmpl.py | 2 +- src/langbot/pkg/provider/modelmgr/token.py | 9 ++- .../unit_tests/provider/test_model_service.py | 64 +++++++++++++++++++ web/src/app/auth/space/callback/page.tsx | 58 ++++++++++++++++- 9 files changed, 153 insertions(+), 14 deletions(-) diff --git a/src/langbot/pkg/api/http/controller/groups/user.py b/src/langbot/pkg/api/http/controller/groups/user.py index ed5548f0..e86d6d1e 100644 --- a/src/langbot/pkg/api/http/controller/groups/user.py +++ b/src/langbot/pkg/api/http/controller/groups/user.py @@ -146,6 +146,7 @@ class UserRouterGroup(group.RouterGroup): return self.fail(3, str(e)) except ValueError as e: traceback.print_exc() + self.ap.logger.warning(f'Space OAuth callback failed: {e}') return self.fail(1, str(e)) except Exception as e: traceback.print_exc() diff --git a/src/langbot/pkg/api/http/service/provider.py b/src/langbot/pkg/api/http/service/provider.py index e15bd40c..598d72e8 100644 --- a/src/langbot/pkg/api/http/service/provider.py +++ b/src/langbot/pkg/api/http/service/provider.py @@ -18,9 +18,22 @@ class ModelProviderService: self.ap = ap @staticmethod - def _normalize_api_keys(api_key: str | None) -> list[str]: - normalized_api_key = api_key.strip() if api_key else '' - return [normalized_api_key] if normalized_api_key else [] + def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]: + if api_keys is None: + return [] + + raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys) + normalized_keys = [] + seen_keys = set() + + for raw_key in raw_keys: + normalized_key = raw_key.strip() if isinstance(raw_key, str) else '' + if not normalized_key or normalized_key in seen_keys: + continue + normalized_keys.append(normalized_key) + seen_keys.add(normalized_key) + + return normalized_keys async def get_providers(self) -> list[dict]: """Get all providers""" @@ -64,6 +77,7 @@ class ModelProviderService: async def create_provider(self, provider_data: dict) -> str: """Create a new provider""" provider_data['uuid'] = str(uuid.uuid4()) + provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys')) await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data) ) @@ -77,6 +91,8 @@ class ModelProviderService: """Update an existing provider""" if 'uuid' in provider_data: del provider_data['uuid'] + if 'api_keys' in provider_data: + provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys')) await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_model.ModelProvider) .where(persistence_model.ModelProvider.uuid == provider_uuid) @@ -146,6 +162,8 @@ class ModelProviderService: async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str: """Find existing provider or create new one""" + api_keys = self._normalize_api_keys(api_keys) + # Try to find existing provider with same config result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_model.ModelProvider).where( @@ -173,7 +191,7 @@ class ModelProviderService: 'name': provider_name, 'requester': requester, 'base_url': base_url, - 'api_keys': api_keys or [], + 'api_keys': api_keys, } ) diff --git a/src/langbot/pkg/platform/sources/web_page_bot_adapter.py b/src/langbot/pkg/platform/sources/web_page_bot_adapter.py index 9b892a10..d424debd 100644 --- a/src/langbot/pkg/platform/sources/web_page_bot_adapter.py +++ b/src/langbot/pkg/platform/sources/web_page_bot_adapter.py @@ -27,10 +27,7 @@ class WebPageBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter listeners: dict = pydantic.Field(default_factory=dict, exclude=True) _ws_adapter: typing.Any = None - class Config: - arbitrary_types_allowed = True - # Allow private attributes - underscore_attrs_are_private = True + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs): super().__init__(config=config, logger=logger, **kwargs) diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index 08fee3ab..cb9a4183 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -340,6 +340,7 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): """Provider API请求器""" name: str = None + init_api_key: str = 'langbot-init-placeholder' ap: app.Application diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py index 89f75993..e63e362b 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -17,7 +17,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): """OpenAI ChatCompletion API 请求器""" client: openai.AsyncClient - init_api_key: str = 'langbot-init-placeholder' default_config: dict[str, typing.Any] = { 'base_url': 'https://api.openai.com/v1', diff --git a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index ed5d8795..c98a71d7 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -25,7 +25,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): async def initialize(self): self.client = openai.AsyncClient( - api_key='', + api_key=self.init_api_key, base_url=self.requester_cfg['base_url'], timeout=self.requester_cfg['timeout'], http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), diff --git a/src/langbot/pkg/provider/modelmgr/token.py b/src/langbot/pkg/provider/modelmgr/token.py index e1a71614..51e97956 100644 --- a/src/langbot/pkg/provider/modelmgr/token.py +++ b/src/langbot/pkg/provider/modelmgr/token.py @@ -14,7 +14,14 @@ class TokenManager: def __init__(self, name: str, tokens: list[str]): self.name = name - self.tokens = tokens + self.tokens = [] + seen_tokens = set() + for token in tokens: + normalized_token = token.strip() if isinstance(token, str) else '' + if not normalized_token or normalized_token in seen_tokens: + continue + self.tokens.append(normalized_token) + seen_tokens.add(normalized_token) self.using_token_index = 0 def get_token(self) -> str: diff --git a/tests/unit_tests/provider/test_model_service.py b/tests/unit_tests/provider/test_model_service.py index b2ea7ba6..dab68043 100644 --- a/tests/unit_tests/provider/test_model_service.py +++ b/tests/unit_tests/provider/test_model_service.py @@ -17,6 +17,8 @@ from langbot.pkg.pipeline.preproc.preproc import PreProcessor from langbot.pkg.provider.modelmgr import requester from langbot.pkg.provider.modelmgr.modelmgr import ModelManager from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions +from langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl import ModelScopeChatCompletions +from langbot.pkg.provider.modelmgr.token import TokenManager from langbot.pkg.provider.runners.localagent import LocalAgentRunner @@ -66,6 +68,17 @@ def test_normalize_space_provider_api_keys_filters_blank_values(): assert ModelProviderService._normalize_api_keys('') == [] assert ModelProviderService._normalize_api_keys(' ') == [] assert ModelProviderService._normalize_api_keys(None) == [] + assert ModelProviderService._normalize_api_keys([' first-key ', '', 'first-key', 'second-key']) == [ + 'first-key', + 'second-key', + ] + + +def test_token_manager_filters_blank_and_duplicate_tokens(): + token_mgr = TokenManager('provider-uuid', [' first-key ', '', 'first-key', 'second-key', ' ']) + + assert token_mgr.tokens == ['first-key', 'second-key'] + assert token_mgr.get_token() == 'first-key' @pytest.mark.asyncio @@ -85,6 +98,57 @@ async def test_openai_requester_initialize_uses_placeholder_api_key(monkeypatch) assert captured_kwargs['api_key'] == OpenAIChatCompletions.init_api_key +@pytest.mark.asyncio +async def test_modelscope_requester_initialize_uses_placeholder_api_key(monkeypatch): + captured_kwargs = {} + + def fake_client(**kwargs): + captured_kwargs.update(kwargs) + return SimpleNamespace(**kwargs) + + monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.openai.AsyncClient', fake_client) + monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.httpx.AsyncClient', fake_client) + + requester_inst = ModelScopeChatCompletions(ap=SimpleNamespace(), config={}) + await requester_inst.initialize() + + assert captured_kwargs['api_key'] == ModelScopeChatCompletions.init_api_key + + +@pytest.mark.asyncio +async def test_openai_embedding_call_overrides_placeholder_api_key(): + captured_request = {} + + async def fake_create(**kwargs): + captured_request['api_key'] = fake_client.api_key + captured_request['kwargs'] = kwargs + return SimpleNamespace( + data=[SimpleNamespace(embedding=[0.1, 0.2])], + usage=SimpleNamespace(prompt_tokens=3, total_tokens=3), + ) + + fake_client = SimpleNamespace( + api_key=OpenAIChatCompletions.init_api_key, + embeddings=SimpleNamespace(create=fake_create), + ) + + requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={}) + requester_inst.client = fake_client + + embeddings, usage_info = await requester_inst.invoke_embedding( + model=requester.RuntimeEmbeddingModel( + model_entity=SimpleNamespace(name='text-embedding-3-small', extra_args={}), + provider=SimpleNamespace(token_mgr=TokenManager('provider-uuid', [' runtime-key ', '', 'runtime-key'])), + ), + input_text=['hello'], + ) + + assert captured_request['api_key'] == 'runtime-key' + assert captured_request['kwargs']['model'] == 'text-embedding-3-small' + assert embeddings == [[0.1, 0.2]] + assert usage_info == {'prompt_tokens': 3, 'total_tokens': 3} + + @pytest.mark.asyncio async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline(): from langbot.pkg.api.http.service.model import LLMModelsService diff --git a/web/src/app/auth/space/callback/page.tsx b/web/src/app/auth/space/callback/page.tsx index 2131a17c..8711cbd6 100644 --- a/web/src/app/auth/space/callback/page.tsx +++ b/web/src/app/auth/space/callback/page.tsx @@ -1,4 +1,4 @@ -import { useEffect, useState, useCallback, Suspense } from 'react'; +import { useEffect, useState, useCallback, Suspense, useRef } from 'react'; import { useNavigate, useSearchParams } from 'react-router-dom'; import { httpClient } from '@/app/infra/http/HttpClient'; import { toast } from 'sonner'; @@ -20,10 +20,39 @@ import { Button } from '@/components/ui/button'; import { LoadingSpinner } from '@/components/ui/loading-spinner'; import langbotIcon from '@/app/assets/langbot-logo.webp'; +type SpaceOAuthLoginResult = { + token: string; + user: string; +}; + +const pendingSpaceOAuthLogins = new Map< + string, + Promise +>(); + +function getOrCreateSpaceOAuthLoginPromise( + authCode: string, +): Promise { + const pendingRequest = pendingSpaceOAuthLogins.get(authCode); + if (pendingRequest) { + return pendingRequest; + } + + const requestPromise = httpClient + .exchangeSpaceOAuthCode(authCode) + .finally(() => { + pendingSpaceOAuthLogins.delete(authCode); + }); + + pendingSpaceOAuthLogins.set(authCode, requestPromise); + return requestPromise; +} + function SpaceOAuthCallbackContent() { const navigate = useNavigate(); const [searchParams] = useSearchParams(); const { t } = useTranslation(); + const isMountedRef = useRef(true); const [status, setStatus] = useState< 'loading' | 'confirm' | 'success' | 'error' @@ -37,7 +66,11 @@ function SpaceOAuthCallbackContent() { const handleOAuthCallback = useCallback( async (authCode: string) => { try { - const response = await httpClient.exchangeSpaceOAuthCode(authCode); + const response = await getOrCreateSpaceOAuthLoginPromise(authCode); + if (!isMountedRef.current) { + return; + } + localStorage.setItem('token', response.token); if (response.user) { localStorage.setItem('userEmail', response.user); @@ -52,6 +85,10 @@ function SpaceOAuthCallbackContent() { navigate(redirectTo); }, 1000); } catch (err) { + if (!isMountedRef.current) { + return; + } + setStatus('error'); const errorObj = err as { msg?: string }; const errMsg = (errorObj?.msg || '').toLowerCase(); @@ -72,6 +109,10 @@ function SpaceOAuthCallbackContent() { setIsProcessing(true); try { const response = await httpClient.bindSpaceAccount(authCode, state); + if (!isMountedRef.current) { + return; + } + localStorage.setItem('token', response.token); if (response.user) { localStorage.setItem('userEmail', response.user); @@ -82,6 +123,10 @@ function SpaceOAuthCallbackContent() { navigate('/home'); }, 1000); } catch (err) { + if (!isMountedRef.current) { + return; + } + setStatus('error'); const errorObj = err as { msg?: string }; const errMsg = (errorObj?.msg || '').toLowerCase(); @@ -91,13 +136,17 @@ function SpaceOAuthCallbackContent() { setErrorMessage(t('account.bindSpaceFailed')); } } finally { - setIsProcessing(false); + if (isMountedRef.current) { + setIsProcessing(false); + } } }, [navigate, t], ); useEffect(() => { + isMountedRef.current = true; + const authCode = searchParams.get('code'); const error = searchParams.get('error'); const errorDescription = searchParams.get('error_description'); @@ -135,6 +184,9 @@ function SpaceOAuthCallbackContent() { // Normal login/register mode handleOAuthCallback(authCode); } + return () => { + isMountedRef.current = false; + }; }, [searchParams, handleOAuthCallback, t]); const handleConfirmBind = () => {