mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: enhance API key normalization and improve Space OAuth callback handling
This commit is contained in:
@@ -146,6 +146,7 @@ class UserRouterGroup(group.RouterGroup):
|
|||||||
return self.fail(3, str(e))
|
return self.fail(3, str(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
self.ap.logger.warning(f'Space OAuth callback failed: {e}')
|
||||||
return self.fail(1, str(e))
|
return self.fail(1, str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|||||||
@@ -18,9 +18,22 @@ class ModelProviderService:
|
|||||||
self.ap = ap
|
self.ap = ap
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_api_keys(api_key: str | None) -> list[str]:
|
def _normalize_api_keys(api_keys: str | list[str] | tuple[str, ...] | None) -> list[str]:
|
||||||
normalized_api_key = api_key.strip() if api_key else ''
|
if api_keys is None:
|
||||||
return [normalized_api_key] if normalized_api_key else []
|
return []
|
||||||
|
|
||||||
|
raw_keys = [api_keys] if isinstance(api_keys, str) else list(api_keys)
|
||||||
|
normalized_keys = []
|
||||||
|
seen_keys = set()
|
||||||
|
|
||||||
|
for raw_key in raw_keys:
|
||||||
|
normalized_key = raw_key.strip() if isinstance(raw_key, str) else ''
|
||||||
|
if not normalized_key or normalized_key in seen_keys:
|
||||||
|
continue
|
||||||
|
normalized_keys.append(normalized_key)
|
||||||
|
seen_keys.add(normalized_key)
|
||||||
|
|
||||||
|
return normalized_keys
|
||||||
|
|
||||||
async def get_providers(self) -> list[dict]:
|
async def get_providers(self) -> list[dict]:
|
||||||
"""Get all providers"""
|
"""Get all providers"""
|
||||||
@@ -64,6 +77,7 @@ class ModelProviderService:
|
|||||||
async def create_provider(self, provider_data: dict) -> str:
|
async def create_provider(self, provider_data: dict) -> str:
|
||||||
"""Create a new provider"""
|
"""Create a new provider"""
|
||||||
provider_data['uuid'] = str(uuid.uuid4())
|
provider_data['uuid'] = str(uuid.uuid4())
|
||||||
|
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
|
sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data)
|
||||||
)
|
)
|
||||||
@@ -77,6 +91,8 @@ class ModelProviderService:
|
|||||||
"""Update an existing provider"""
|
"""Update an existing provider"""
|
||||||
if 'uuid' in provider_data:
|
if 'uuid' in provider_data:
|
||||||
del provider_data['uuid']
|
del provider_data['uuid']
|
||||||
|
if 'api_keys' in provider_data:
|
||||||
|
provider_data['api_keys'] = self._normalize_api_keys(provider_data.get('api_keys'))
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.update(persistence_model.ModelProvider)
|
sqlalchemy.update(persistence_model.ModelProvider)
|
||||||
.where(persistence_model.ModelProvider.uuid == provider_uuid)
|
.where(persistence_model.ModelProvider.uuid == provider_uuid)
|
||||||
@@ -146,6 +162,8 @@ class ModelProviderService:
|
|||||||
|
|
||||||
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
async def find_or_create_provider(self, requester: str, base_url: str, api_keys: list) -> str:
|
||||||
"""Find existing provider or create new one"""
|
"""Find existing provider or create new one"""
|
||||||
|
api_keys = self._normalize_api_keys(api_keys)
|
||||||
|
|
||||||
# Try to find existing provider with same config
|
# Try to find existing provider with same config
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_model.ModelProvider).where(
|
sqlalchemy.select(persistence_model.ModelProvider).where(
|
||||||
@@ -173,7 +191,7 @@ class ModelProviderService:
|
|||||||
'name': provider_name,
|
'name': provider_name,
|
||||||
'requester': requester,
|
'requester': requester,
|
||||||
'base_url': base_url,
|
'base_url': base_url,
|
||||||
'api_keys': api_keys or [],
|
'api_keys': api_keys,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -27,10 +27,7 @@ class WebPageBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
|||||||
listeners: dict = pydantic.Field(default_factory=dict, exclude=True)
|
listeners: dict = pydantic.Field(default_factory=dict, exclude=True)
|
||||||
_ws_adapter: typing.Any = None
|
_ws_adapter: typing.Any = None
|
||||||
|
|
||||||
class Config:
|
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
||||||
arbitrary_types_allowed = True
|
|
||||||
# Allow private attributes
|
|
||||||
underscore_attrs_are_private = True
|
|
||||||
|
|
||||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||||
super().__init__(config=config, logger=logger, **kwargs)
|
super().__init__(config=config, logger=logger, **kwargs)
|
||||||
|
|||||||
@@ -340,6 +340,7 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
|||||||
"""Provider API请求器"""
|
"""Provider API请求器"""
|
||||||
|
|
||||||
name: str = None
|
name: str = None
|
||||||
|
init_api_key: str = 'langbot-init-placeholder'
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
|||||||
"""OpenAI ChatCompletion API 请求器"""
|
"""OpenAI ChatCompletion API 请求器"""
|
||||||
|
|
||||||
client: openai.AsyncClient
|
client: openai.AsyncClient
|
||||||
init_api_key: str = 'langbot-init-placeholder'
|
|
||||||
|
|
||||||
default_config: dict[str, typing.Any] = {
|
default_config: dict[str, typing.Any] = {
|
||||||
'base_url': 'https://api.openai.com/v1',
|
'base_url': 'https://api.openai.com/v1',
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
|||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
self.client = openai.AsyncClient(
|
self.client = openai.AsyncClient(
|
||||||
api_key='',
|
api_key=self.init_api_key,
|
||||||
base_url=self.requester_cfg['base_url'],
|
base_url=self.requester_cfg['base_url'],
|
||||||
timeout=self.requester_cfg['timeout'],
|
timeout=self.requester_cfg['timeout'],
|
||||||
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),
|
||||||
|
|||||||
@@ -14,7 +14,14 @@ class TokenManager:
|
|||||||
|
|
||||||
def __init__(self, name: str, tokens: list[str]):
|
def __init__(self, name: str, tokens: list[str]):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.tokens = tokens
|
self.tokens = []
|
||||||
|
seen_tokens = set()
|
||||||
|
for token in tokens:
|
||||||
|
normalized_token = token.strip() if isinstance(token, str) else ''
|
||||||
|
if not normalized_token or normalized_token in seen_tokens:
|
||||||
|
continue
|
||||||
|
self.tokens.append(normalized_token)
|
||||||
|
seen_tokens.add(normalized_token)
|
||||||
self.using_token_index = 0
|
self.using_token_index = 0
|
||||||
|
|
||||||
def get_token(self) -> str:
|
def get_token(self) -> str:
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ from langbot.pkg.pipeline.preproc.preproc import PreProcessor
|
|||||||
from langbot.pkg.provider.modelmgr import requester
|
from langbot.pkg.provider.modelmgr import requester
|
||||||
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
|
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
|
||||||
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
|
||||||
|
from langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl import ModelScopeChatCompletions
|
||||||
|
from langbot.pkg.provider.modelmgr.token import TokenManager
|
||||||
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
from langbot.pkg.provider.runners.localagent import LocalAgentRunner
|
||||||
|
|
||||||
|
|
||||||
@@ -66,6 +68,17 @@ def test_normalize_space_provider_api_keys_filters_blank_values():
|
|||||||
assert ModelProviderService._normalize_api_keys('') == []
|
assert ModelProviderService._normalize_api_keys('') == []
|
||||||
assert ModelProviderService._normalize_api_keys(' ') == []
|
assert ModelProviderService._normalize_api_keys(' ') == []
|
||||||
assert ModelProviderService._normalize_api_keys(None) == []
|
assert ModelProviderService._normalize_api_keys(None) == []
|
||||||
|
assert ModelProviderService._normalize_api_keys([' first-key ', '', 'first-key', 'second-key']) == [
|
||||||
|
'first-key',
|
||||||
|
'second-key',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_manager_filters_blank_and_duplicate_tokens():
|
||||||
|
token_mgr = TokenManager('provider-uuid', [' first-key ', '', 'first-key', 'second-key', ' '])
|
||||||
|
|
||||||
|
assert token_mgr.tokens == ['first-key', 'second-key']
|
||||||
|
assert token_mgr.get_token() == 'first-key'
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -85,6 +98,57 @@ async def test_openai_requester_initialize_uses_placeholder_api_key(monkeypatch)
|
|||||||
assert captured_kwargs['api_key'] == OpenAIChatCompletions.init_api_key
|
assert captured_kwargs['api_key'] == OpenAIChatCompletions.init_api_key
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_modelscope_requester_initialize_uses_placeholder_api_key(monkeypatch):
|
||||||
|
captured_kwargs = {}
|
||||||
|
|
||||||
|
def fake_client(**kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
return SimpleNamespace(**kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.openai.AsyncClient', fake_client)
|
||||||
|
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.modelscopechatcmpl.httpx.AsyncClient', fake_client)
|
||||||
|
|
||||||
|
requester_inst = ModelScopeChatCompletions(ap=SimpleNamespace(), config={})
|
||||||
|
await requester_inst.initialize()
|
||||||
|
|
||||||
|
assert captured_kwargs['api_key'] == ModelScopeChatCompletions.init_api_key
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_embedding_call_overrides_placeholder_api_key():
|
||||||
|
captured_request = {}
|
||||||
|
|
||||||
|
async def fake_create(**kwargs):
|
||||||
|
captured_request['api_key'] = fake_client.api_key
|
||||||
|
captured_request['kwargs'] = kwargs
|
||||||
|
return SimpleNamespace(
|
||||||
|
data=[SimpleNamespace(embedding=[0.1, 0.2])],
|
||||||
|
usage=SimpleNamespace(prompt_tokens=3, total_tokens=3),
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_client = SimpleNamespace(
|
||||||
|
api_key=OpenAIChatCompletions.init_api_key,
|
||||||
|
embeddings=SimpleNamespace(create=fake_create),
|
||||||
|
)
|
||||||
|
|
||||||
|
requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={})
|
||||||
|
requester_inst.client = fake_client
|
||||||
|
|
||||||
|
embeddings, usage_info = await requester_inst.invoke_embedding(
|
||||||
|
model=requester.RuntimeEmbeddingModel(
|
||||||
|
model_entity=SimpleNamespace(name='text-embedding-3-small', extra_args={}),
|
||||||
|
provider=SimpleNamespace(token_mgr=TokenManager('provider-uuid', [' runtime-key ', '', 'runtime-key'])),
|
||||||
|
),
|
||||||
|
input_text=['hello'],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured_request['api_key'] == 'runtime-key'
|
||||||
|
assert captured_request['kwargs']['model'] == 'text-embedding-3-small'
|
||||||
|
assert embeddings == [[0.1, 0.2]]
|
||||||
|
assert usage_info == {'prompt_tokens': 3, 'total_tokens': 3}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline():
|
async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline():
|
||||||
from langbot.pkg.api.http.service.model import LLMModelsService
|
from langbot.pkg.api.http.service.model import LLMModelsService
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useEffect, useState, useCallback, Suspense } from 'react';
|
import { useEffect, useState, useCallback, Suspense, useRef } from 'react';
|
||||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||||
import { toast } from 'sonner';
|
import { toast } from 'sonner';
|
||||||
@@ -20,10 +20,39 @@ import { Button } from '@/components/ui/button';
|
|||||||
import { LoadingSpinner } from '@/components/ui/loading-spinner';
|
import { LoadingSpinner } from '@/components/ui/loading-spinner';
|
||||||
import langbotIcon from '@/app/assets/langbot-logo.webp';
|
import langbotIcon from '@/app/assets/langbot-logo.webp';
|
||||||
|
|
||||||
|
type SpaceOAuthLoginResult = {
|
||||||
|
token: string;
|
||||||
|
user: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const pendingSpaceOAuthLogins = new Map<
|
||||||
|
string,
|
||||||
|
Promise<SpaceOAuthLoginResult>
|
||||||
|
>();
|
||||||
|
|
||||||
|
function getOrCreateSpaceOAuthLoginPromise(
|
||||||
|
authCode: string,
|
||||||
|
): Promise<SpaceOAuthLoginResult> {
|
||||||
|
const pendingRequest = pendingSpaceOAuthLogins.get(authCode);
|
||||||
|
if (pendingRequest) {
|
||||||
|
return pendingRequest;
|
||||||
|
}
|
||||||
|
|
||||||
|
const requestPromise = httpClient
|
||||||
|
.exchangeSpaceOAuthCode(authCode)
|
||||||
|
.finally(() => {
|
||||||
|
pendingSpaceOAuthLogins.delete(authCode);
|
||||||
|
});
|
||||||
|
|
||||||
|
pendingSpaceOAuthLogins.set(authCode, requestPromise);
|
||||||
|
return requestPromise;
|
||||||
|
}
|
||||||
|
|
||||||
function SpaceOAuthCallbackContent() {
|
function SpaceOAuthCallbackContent() {
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const [searchParams] = useSearchParams();
|
const [searchParams] = useSearchParams();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const isMountedRef = useRef(true);
|
||||||
|
|
||||||
const [status, setStatus] = useState<
|
const [status, setStatus] = useState<
|
||||||
'loading' | 'confirm' | 'success' | 'error'
|
'loading' | 'confirm' | 'success' | 'error'
|
||||||
@@ -37,7 +66,11 @@ function SpaceOAuthCallbackContent() {
|
|||||||
const handleOAuthCallback = useCallback(
|
const handleOAuthCallback = useCallback(
|
||||||
async (authCode: string) => {
|
async (authCode: string) => {
|
||||||
try {
|
try {
|
||||||
const response = await httpClient.exchangeSpaceOAuthCode(authCode);
|
const response = await getOrCreateSpaceOAuthLoginPromise(authCode);
|
||||||
|
if (!isMountedRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
localStorage.setItem('token', response.token);
|
localStorage.setItem('token', response.token);
|
||||||
if (response.user) {
|
if (response.user) {
|
||||||
localStorage.setItem('userEmail', response.user);
|
localStorage.setItem('userEmail', response.user);
|
||||||
@@ -52,6 +85,10 @@ function SpaceOAuthCallbackContent() {
|
|||||||
navigate(redirectTo);
|
navigate(redirectTo);
|
||||||
}, 1000);
|
}, 1000);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
if (!isMountedRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
setStatus('error');
|
setStatus('error');
|
||||||
const errorObj = err as { msg?: string };
|
const errorObj = err as { msg?: string };
|
||||||
const errMsg = (errorObj?.msg || '').toLowerCase();
|
const errMsg = (errorObj?.msg || '').toLowerCase();
|
||||||
@@ -72,6 +109,10 @@ function SpaceOAuthCallbackContent() {
|
|||||||
setIsProcessing(true);
|
setIsProcessing(true);
|
||||||
try {
|
try {
|
||||||
const response = await httpClient.bindSpaceAccount(authCode, state);
|
const response = await httpClient.bindSpaceAccount(authCode, state);
|
||||||
|
if (!isMountedRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
localStorage.setItem('token', response.token);
|
localStorage.setItem('token', response.token);
|
||||||
if (response.user) {
|
if (response.user) {
|
||||||
localStorage.setItem('userEmail', response.user);
|
localStorage.setItem('userEmail', response.user);
|
||||||
@@ -82,6 +123,10 @@ function SpaceOAuthCallbackContent() {
|
|||||||
navigate('/home');
|
navigate('/home');
|
||||||
}, 1000);
|
}, 1000);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
if (!isMountedRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
setStatus('error');
|
setStatus('error');
|
||||||
const errorObj = err as { msg?: string };
|
const errorObj = err as { msg?: string };
|
||||||
const errMsg = (errorObj?.msg || '').toLowerCase();
|
const errMsg = (errorObj?.msg || '').toLowerCase();
|
||||||
@@ -91,13 +136,17 @@ function SpaceOAuthCallbackContent() {
|
|||||||
setErrorMessage(t('account.bindSpaceFailed'));
|
setErrorMessage(t('account.bindSpaceFailed'));
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
setIsProcessing(false);
|
if (isMountedRef.current) {
|
||||||
|
setIsProcessing(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[navigate, t],
|
[navigate, t],
|
||||||
);
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
isMountedRef.current = true;
|
||||||
|
|
||||||
const authCode = searchParams.get('code');
|
const authCode = searchParams.get('code');
|
||||||
const error = searchParams.get('error');
|
const error = searchParams.get('error');
|
||||||
const errorDescription = searchParams.get('error_description');
|
const errorDescription = searchParams.get('error_description');
|
||||||
@@ -135,6 +184,9 @@ function SpaceOAuthCallbackContent() {
|
|||||||
// Normal login/register mode
|
// Normal login/register mode
|
||||||
handleOAuthCallback(authCode);
|
handleOAuthCallback(authCode);
|
||||||
}
|
}
|
||||||
|
return () => {
|
||||||
|
isMountedRef.current = false;
|
||||||
|
};
|
||||||
}, [searchParams, handleOAuthCallback, t]);
|
}, [searchParams, handleOAuthCallback, t]);
|
||||||
|
|
||||||
const handleConfirmBind = () => {
|
const handleConfirmBind = () => {
|
||||||
|
|||||||
Reference in New Issue
Block a user