mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: enhance API key normalization and improve Space OAuth callback handling
This commit is contained in:
@@ -146,6 +146,7 @@ class UserRouterGroup(group.RouterGroup):
|
||||
return self.fail(3, str(e))
|
||||
except ValueError as e:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.warning(f'Space OAuth callback failed: {e}')
|
||||
return self.fail(1, str(e))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -18,9 +18,22 @@ class ModelProviderService:
|
||||
self.ap = ap
|
||||
|
||||
@staticmethod
|
||||
def _normalize_api_keys(api_key: str | None) -> list[str]:
|
||||
normalized_api_key = api_key.strip() if api_key else ''
|
||||
return [normalized_api_key] if normalized_api_key else []
|
||||
def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
|
||||
if api_keys is None:
|
||||
return []
|
||||
|
||||
raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
|
||||
normalized_keys = []
|
||||
seen_keys = set()
|
||||
|
||||
for raw_key in raw_keys:
|
||||
normalized_key = raw_key.strip() if isinstance(raw_key, str) else ''
|
||||
if not normalized_key or normalized_key in seen_keys:
|
||||
continue
|
||||
normalized_keys.append(normalized_key)
|
||||
seen_keys.add(normalized_key)
|
||||
|
||||
return normalized_keys
|
||||
|
||||
async def get_providers(self) -> list[dict]:
|
||||
"""Get all providers"""
|
||||
@@ -64,6 +77,7 @@ class ModelProviderService:
|
||||
async def create_provider(self, provider_data: dict) -> str:
|
||||
"""Create a new provider"""
|
||||
provider_data['uuid'] = str(uuid.uuid4())
|
||||
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
|
||||
)
|
||||
@@ -77,6 +91,8 @@ class ModelProviderService:
|
||||
"""Update an existing provider"""
|
||||
if 'uuid' in provider_data:
|
||||
del provider_data['uuid']
|
||||
if 'api_keys' in provider_data:
|
||||
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_model.ModelProvider)
|
||||
.where(persistence_model.ModelProvider.uuid == provider_uuid)
|
||||
@@ -146,6 +162,8 @@ class ModelProviderService:
|
||||
|
||||
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
||||
"""Find existing provider or create new one"""
|
||||
api_keys = self._normalize_api_keys(api_keys)
|
||||
|
||||
# Try to find existing provider with same config
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||
@@ -173,7 +191,7 @@ class ModelProviderService:
|
||||
'name': provider_name,
|
||||
'requester': requester,
|
||||
'base_url': base_url,
|
||||
'api_keys': api_keys or [],
|
||||
'api_keys': api_keys,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -27,10 +27,7 @@ class WebPageBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
listeners: dict = pydantic.Field(default_factory=dict, exclude=True)
|
||||
_ws_adapter: typing.Any = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
# Allow private attributes
|
||||
underscore_attrs_are_private = True
|
||||
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
super().__init__(config=config, logger=logger, **kwargs)
|
||||
|
||||
@@ -340,6 +340,7 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""Provider API请求器"""
|
||||
|
||||
name: str = None
|
||||
init_api_key: str = 'langbot-init-placeholder'
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
"""OpenAI ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
init_api_key: str = 'langbot-init-placeholder'
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base_url': 'https://api.openai.com/v1',
|
||||
|
||||
@@ -25,7 +25,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.AsyncClient(
|
||||
api_key='',
|
||||
api_key=self.init_api_key,
|
||||
base_url=self.requester_cfg['base_url'],
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
||||
|
||||
@@ -14,7 +14,14 @@ class TokenManager:
|
||||
|
||||
def __init__(self, name: str, tokens: list[str]):
|
||||
self.name = name
|
||||
self.tokens = tokens
|
||||
self.tokens = []
|
||||
seen_tokens = set()
|
||||
for token in tokens:
|
||||
normalized_token = token.strip() if isinstance(token, str) else ''
|
||||
if not normalized_token or normalized_token in seen_tokens:
|
||||
continue
|
||||
self.tokens.append(normalized_token)
|
||||
seen_tokens.add(normalized_token)
|
||||
self.using_token_index = 0
|
||||
|
||||
def get_token(self) -> str:
|
||||
|
||||
@@ -17,6 +17,8 @@ from langbot.pkg.pipeline.preproc.preproc import PreProcessor
|
||||
from langbot.pkg.provider.modelmgr import requester
|
||||
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
|
||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||
from langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl import ModelScopeChatCompletions
|
||||
from langbot.pkg.provider.modelmgr.token import TokenManager
|
||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||
|
||||
|
||||
@@ -66,6 +68,17 @@ def test_normalize_space_provider_api_keys_filters_blank_values():
|
||||
assert ModelProviderService._normalize_api_keys('') == []
|
||||
assert ModelProviderService._normalize_api_keys(' ') == []
|
||||
assert ModelProviderService._normalize_api_keys(None) == []
|
||||
assert ModelProviderService._normalize_api_keys([' first-key ', '', 'first-key', 'second-key']) == [
|
||||
'first-key',
|
||||
'second-key',
|
||||
]
|
||||
|
||||
|
||||
def test_token_manager_filters_blank_and_duplicate_tokens():
|
||||
token_mgr = TokenManager('provider-uuid', [' first-key ', '', 'first-key', 'second-key', ' '])
|
||||
|
||||
assert token_mgr.tokens == ['first-key', 'second-key']
|
||||
assert token_mgr.get_token() == 'first-key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -85,6 +98,57 @@ async def test_openai_requester_initialize_uses_placeholder_api_key(monkeypatch)
|
||||
assert captured_kwargs['api_key'] == OpenAIChatCompletions.init_api_key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modelscope_requester_initialize_uses_placeholder_api_key(monkeypatch):
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_client(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return SimpleNamespace(**kwargs)
|
||||
|
||||
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.openai.AsyncClient', fake_client)
|
||||
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.httpx.AsyncClient', fake_client)
|
||||
|
||||
requester_inst = ModelScopeChatCompletions(ap=SimpleNamespace(), config={})
|
||||
await requester_inst.initialize()
|
||||
|
||||
assert captured_kwargs['api_key'] == ModelScopeChatCompletions.init_api_key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_embedding_call_overrides_placeholder_api_key():
|
||||
captured_request = {}
|
||||
|
||||
async def fake_create(**kwargs):
|
||||
captured_request['api_key'] = fake_client.api_key
|
||||
captured_request['kwargs'] = kwargs
|
||||
return SimpleNamespace(
|
||||
data=[SimpleNamespace(embedding=[0.1, 0.2])],
|
||||
usage=SimpleNamespace(prompt_tokens=3, total_tokens=3),
|
||||
)
|
||||
|
||||
fake_client = SimpleNamespace(
|
||||
api_key=OpenAIChatCompletions.init_api_key,
|
||||
embeddings=SimpleNamespace(create=fake_create),
|
||||
)
|
||||
|
||||
requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={})
|
||||
requester_inst.client = fake_client
|
||||
|
||||
embeddings, usage_info = await requester_inst.invoke_embedding(
|
||||
model=requester.RuntimeEmbeddingModel(
|
||||
model_entity=SimpleNamespace(name='text-embedding-3-small', extra_args={}),
|
||||
provider=SimpleNamespace(token_mgr=TokenManager('provider-uuid', [' runtime-key ', '', 'runtime-key'])),
|
||||
),
|
||||
input_text=['hello'],
|
||||
)
|
||||
|
||||
assert captured_request['api_key'] == 'runtime-key'
|
||||
assert captured_request['kwargs']['model'] == 'text-embedding-3-small'
|
||||
assert embeddings == [[0.1, 0.2]]
|
||||
assert usage_info == {'prompt_tokens': 3, 'total_tokens': 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline():
|
||||
from langbot.pkg.api.http.service.model import LLMModelsService
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useState, useCallback, Suspense } from 'react';
|
||||
import { useEffect, useState, useCallback, Suspense, useRef } from 'react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||
import { toast } from 'sonner';
|
||||
@@ -20,10 +20,39 @@ import { Button } from '@/components/ui/button';
|
||||
import { LoadingSpinner } from '@/components/ui/loading-spinner';
|
||||
import langbotIcon from '@/app/assets/langbot-logo.webp';
|
||||
|
||||
type SpaceOAuthLoginResult = {
|
||||
token: string;
|
||||
user: string;
|
||||
};
|
||||
|
||||
const pendingSpaceOAuthLogins = new Map<
|
||||
string,
|
||||
Promise<SpaceOAuthLoginResult>
|
||||
>();
|
||||
|
||||
function getOrCreateSpaceOAuthLoginPromise(
|
||||
authCode: string,
|
||||
): Promise<SpaceOAuthLoginResult> {
|
||||
const pendingRequest = pendingSpaceOAuthLogins.get(authCode);
|
||||
if (pendingRequest) {
|
||||
return pendingRequest;
|
||||
}
|
||||
|
||||
const requestPromise = httpClient
|
||||
.exchangeSpaceOAuthCode(authCode)
|
||||
.finally(() => {
|
||||
pendingSpaceOAuthLogins.delete(authCode);
|
||||
});
|
||||
|
||||
pendingSpaceOAuthLogins.set(authCode, requestPromise);
|
||||
return requestPromise;
|
||||
}
|
||||
|
||||
function SpaceOAuthCallbackContent() {
|
||||
const navigate = useNavigate();
|
||||
const [searchParams] = useSearchParams();
|
||||
const { t } = useTranslation();
|
||||
const isMountedRef = useRef(true);
|
||||
|
||||
const [status, setStatus] = useState<
|
||||
'loading' | 'confirm' | 'success' | 'error'
|
||||
@@ -37,7 +66,11 @@ function SpaceOAuthCallbackContent() {
|
||||
const handleOAuthCallback = useCallback(
|
||||
async (authCode: string) => {
|
||||
try {
|
||||
const response = await httpClient.exchangeSpaceOAuthCode(authCode);
|
||||
const response = await getOrCreateSpaceOAuthLoginPromise(authCode);
|
||||
if (!isMountedRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
localStorage.setItem('token', response.token);
|
||||
if (response.user) {
|
||||
localStorage.setItem('userEmail', response.user);
|
||||
@@ -52,6 +85,10 @@ function SpaceOAuthCallbackContent() {
|
||||
navigate(redirectTo);
|
||||
}, 1000);
|
||||
} catch (err) {
|
||||
if (!isMountedRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
setStatus('error');
|
||||
const errorObj = err as { msg?: string };
|
||||
const errMsg = (errorObj?.msg || '').toLowerCase();
|
||||
@@ -72,6 +109,10 @@ function SpaceOAuthCallbackContent() {
|
||||
setIsProcessing(true);
|
||||
try {
|
||||
const response = await httpClient.bindSpaceAccount(authCode, state);
|
||||
if (!isMountedRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
localStorage.setItem('token', response.token);
|
||||
if (response.user) {
|
||||
localStorage.setItem('userEmail', response.user);
|
||||
@@ -82,6 +123,10 @@ function SpaceOAuthCallbackContent() {
|
||||
navigate('/home');
|
||||
}, 1000);
|
||||
} catch (err) {
|
||||
if (!isMountedRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
setStatus('error');
|
||||
const errorObj = err as { msg?: string };
|
||||
const errMsg = (errorObj?.msg || '').toLowerCase();
|
||||
@@ -91,13 +136,17 @@ function SpaceOAuthCallbackContent() {
|
||||
setErrorMessage(t('account.bindSpaceFailed'));
|
||||
}
|
||||
} finally {
|
||||
if (isMountedRef.current) {
|
||||
setIsProcessing(false);
|
||||
}
|
||||
}
|
||||
},
|
||||
[navigate, t],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
isMountedRef.current = true;
|
||||
|
||||
const authCode = searchParams.get('code');
|
||||
const error = searchParams.get('error');
|
||||
const errorDescription = searchParams.get('error_description');
|
||||
@@ -135,6 +184,9 @@ function SpaceOAuthCallbackContent() {
|
||||
// Normal login/register mode
|
||||
handleOAuthCallback(authCode);
|
||||
}
|
||||
return () => {
|
||||
isMountedRef.current = false;
|
||||
};
|
||||
}, [searchParams, handleOAuthCallback, t]);
|
||||
|
||||
const handleConfirmBind = () => {
|
||||
|
||||
Reference in New Issue
Block a user