feat: enhance API key normalization and improve Space OAuth callback handling

This commit is contained in:
fdc310
2026-05-11 15:03:30 +08:00
parent ea13ef87f2
commit 6713b57d01
9 changed files with 153 additions and 14 deletions

View File

@@ -146,6 +146,7 @@ class UserRouterGroup(group.RouterGroup):
return self.fail(3, str(e))
except ValueError as e:
traceback.print_exc()
self.ap.logger.warning(f'Space OAuth callback failed: {e}')
return self.fail(1, str(e))
except Exception as e:
traceback.print_exc()

View File

@@ -18,9 +18,22 @@ class ModelProviderService:
self.ap = ap
@staticmethod
def _normalize_api_keys(api_key: str | None) -> list[str]:
normalized_api_key = api_key.strip() if api_key else ''
return [normalized_api_key] if normalized_api_key else []
def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
if api_keys is None:
return []
raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
normalized_keys = []
seen_keys = set()
for raw_key in raw_keys:
normalized_key = raw_key.strip() if isinstance(raw_key, str) else ''
if not normalized_key or normalized_key in seen_keys:
continue
normalized_keys.append(normalized_key)
seen_keys.add(normalized_key)
return normalized_keys
async def get_providers(self) -> list[dict]:
"""Get all providers"""
@@ -64,6 +77,7 @@ class ModelProviderService:
async def create_provider(self, provider_data: dict) -> str:
"""Create a new provider"""
provider_data['uuid'] = str(uuid.uuid4())
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
)
@@ -77,6 +91,8 @@ class ModelProviderService:
"""Update an existing provider"""
if 'uuid' in provider_data:
del provider_data['uuid']
if 'api_keys' in provider_data:
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == provider_uuid)
@@ -146,6 +162,8 @@ class ModelProviderService:
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
"""Find existing provider or create new one"""
api_keys = self._normalize_api_keys(api_keys)
# Try to find existing provider with same config
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.ModelProvider).where(
@@ -173,7 +191,7 @@ class ModelProviderService:
'name': provider_name,
'requester': requester,
'base_url': base_url,
'api_keys': api_keys or [],
'api_keys': api_keys,
}
)

View File

@@ -27,10 +27,7 @@ class WebPageBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
listeners: dict = pydantic.Field(default_factory=dict, exclude=True)
_ws_adapter: typing.Any = None
class Config:
arbitrary_types_allowed = True
# Allow private attributes
underscore_attrs_are_private = True
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
super().__init__(config=config, logger=logger, **kwargs)

View File

@@ -340,6 +340,7 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
"""Provider API请求器"""
name: str = None
init_api_key: str = 'langbot-init-placeholder'
ap: app.Application

View File

@@ -17,7 +17,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
"""OpenAI ChatCompletion API 请求器"""
client: openai.AsyncClient
init_api_key: str = 'langbot-init-placeholder'
default_config: dict[str, typing.Any] = {
'base_url': 'https://api.openai.com/v1',

View File

@@ -25,7 +25,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
async def initialize(self):
self.client = openai.AsyncClient(
api_key='',
api_key=self.init_api_key,
base_url=self.requester_cfg['base_url'],
timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),

View File

@@ -14,7 +14,14 @@ class TokenManager:
def __init__(self, name: str, tokens: list[str]):
self.name = name
self.tokens = tokens
self.tokens = []
seen_tokens = set()
for token in tokens:
normalized_token = token.strip() if isinstance(token, str) else ''
if not normalized_token or normalized_token in seen_tokens:
continue
self.tokens.append(normalized_token)
seen_tokens.add(normalized_token)
self.using_token_index = 0
def get_token(self) -> str:

View File

@@ -17,6 +17,8 @@ from langbot.pkg.pipeline.preproc.preproc import PreProcessor
from langbot.pkg.provider.modelmgr import requester
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
from langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl import ModelScopeChatCompletions
from langbot.pkg.provider.modelmgr.token import TokenManager
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
@@ -66,6 +68,17 @@ def test_normalize_space_provider_api_keys_filters_blank_values():
assert ModelProviderService._normalize_api_keys('') == []
assert ModelProviderService._normalize_api_keys(' ') == []
assert ModelProviderService._normalize_api_keys(None) == []
assert ModelProviderService._normalize_api_keys([' first-key ', '', 'first-key', 'second-key']) == [
'first-key',
'second-key',
]
def test_token_manager_filters_blank_and_duplicate_tokens():
token_mgr = TokenManager('provider-uuid', [' first-key ', '', 'first-key', 'second-key', ' '])
assert token_mgr.tokens == ['first-key', 'second-key']
assert token_mgr.get_token() == 'first-key'
@pytest.mark.asyncio
@@ -85,6 +98,57 @@ async def test_openai_requester_initialize_uses_placeholder_api_key(monkeypatch)
assert captured_kwargs['api_key'] == OpenAIChatCompletions.init_api_key
@pytest.mark.asyncio
async def test_modelscope_requester_initialize_uses_placeholder_api_key(monkeypatch):
captured_kwargs = {}
def fake_client(**kwargs):
captured_kwargs.update(kwargs)
return SimpleNamespace(**kwargs)
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.openai.AsyncClient', fake_client)
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.httpx.AsyncClient', fake_client)
requester_inst = ModelScopeChatCompletions(ap=SimpleNamespace(), config={})
await requester_inst.initialize()
assert captured_kwargs['api_key'] == ModelScopeChatCompletions.init_api_key
@pytest.mark.asyncio
async def test_openai_embedding_call_overrides_placeholder_api_key():
captured_request = {}
async def fake_create(**kwargs):
captured_request['api_key'] = fake_client.api_key
captured_request['kwargs'] = kwargs
return SimpleNamespace(
data=[SimpleNamespace(embedding=[0.1, 0.2])],
usage=SimpleNamespace(prompt_tokens=3, total_tokens=3),
)
fake_client = SimpleNamespace(
api_key=OpenAIChatCompletions.init_api_key,
embeddings=SimpleNamespace(create=fake_create),
)
requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={})
requester_inst.client = fake_client
embeddings, usage_info = await requester_inst.invoke_embedding(
model=requester.RuntimeEmbeddingModel(
model_entity=SimpleNamespace(name='text-embedding-3-small', extra_args={}),
provider=SimpleNamespace(token_mgr=TokenManager('provider-uuid', [' runtime-key ', '', 'runtime-key'])),
),
input_text=['hello'],
)
assert captured_request['api_key'] == 'runtime-key'
assert captured_request['kwargs']['model'] == 'text-embedding-3-small'
assert embeddings == [[0.1, 0.2]]
assert usage_info == {'prompt_tokens': 3, 'total_tokens': 3}
@pytest.mark.asyncio
async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline():
from langbot.pkg.api.http.service.model import LLMModelsService

View File

@@ -1,4 +1,4 @@
import { useEffect, useState, useCallback, Suspense } from 'react';
import { useEffect, useState, useCallback, Suspense, useRef } from 'react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { httpClient } from '@/app/infra/http/HttpClient';
import { toast } from 'sonner';
@@ -20,10 +20,39 @@ import { Button } from '@/components/ui/button';
import { LoadingSpinner } from '@/components/ui/loading-spinner';
import langbotIcon from '@/app/assets/langbot-logo.webp';
type SpaceOAuthLoginResult = {
token: string;
user: string;
};
const pendingSpaceOAuthLogins = new Map<
string,
Promise<SpaceOAuthLoginResult>
>();
function getOrCreateSpaceOAuthLoginPromise(
authCode: string,
): Promise<SpaceOAuthLoginResult> {
const pendingRequest = pendingSpaceOAuthLogins.get(authCode);
if (pendingRequest) {
return pendingRequest;
}
const requestPromise = httpClient
.exchangeSpaceOAuthCode(authCode)
.finally(() => {
pendingSpaceOAuthLogins.delete(authCode);
});
pendingSpaceOAuthLogins.set(authCode, requestPromise);
return requestPromise;
}
function SpaceOAuthCallbackContent() {
const navigate = useNavigate();
const [searchParams] = useSearchParams();
const { t } = useTranslation();
const isMountedRef = useRef(true);
const [status, setStatus] = useState<
'loading' | 'confirm' | 'success' | 'error'
@@ -37,7 +66,11 @@ function SpaceOAuthCallbackContent() {
const handleOAuthCallback = useCallback(
async (authCode: string) => {
try {
const response = await httpClient.exchangeSpaceOAuthCode(authCode);
const response = await getOrCreateSpaceOAuthLoginPromise(authCode);
if (!isMountedRef.current) {
return;
}
localStorage.setItem('token', response.token);
if (response.user) {
localStorage.setItem('userEmail', response.user);
@@ -52,6 +85,10 @@ function SpaceOAuthCallbackContent() {
navigate(redirectTo);
}, 1000);
} catch (err) {
if (!isMountedRef.current) {
return;
}
setStatus('error');
const errorObj = err as { msg?: string };
const errMsg = (errorObj?.msg || '').toLowerCase();
@@ -72,6 +109,10 @@ function SpaceOAuthCallbackContent() {
setIsProcessing(true);
try {
const response = await httpClient.bindSpaceAccount(authCode, state);
if (!isMountedRef.current) {
return;
}
localStorage.setItem('token', response.token);
if (response.user) {
localStorage.setItem('userEmail', response.user);
@@ -82,6 +123,10 @@ function SpaceOAuthCallbackContent() {
navigate('/home');
}, 1000);
} catch (err) {
if (!isMountedRef.current) {
return;
}
setStatus('error');
const errorObj = err as { msg?: string };
const errMsg = (errorObj?.msg || '').toLowerCase();
@@ -91,13 +136,17 @@ function SpaceOAuthCallbackContent() {
setErrorMessage(t('account.bindSpaceFailed'));
}
} finally {
setIsProcessing(false);
if (isMountedRef.current) {
setIsProcessing(false);
}
}
},
[navigate, t],
);
useEffect(() => {
isMountedRef.current = true;
const authCode = searchParams.get('code');
const error = searchParams.get('error');
const errorDescription = searchParams.get('error_description');
@@ -135,6 +184,9 @@ function SpaceOAuthCallbackContent() {
// Normal login/register mode
handleOAuthCallback(authCode);
}
return () => {
isMountedRef.current = false;
};
}, [searchParams, handleOAuthCallback, t]);
const handleConfirmBind = () => {