feat(provider): add API key normalization and update OpenAI requester initialization

This commit is contained in:
fdc310
2026-05-11 14:21:42 +08:00
parent 59bd581e88
commit ea13ef87f2
3 changed files with 35 additions and 2 deletions

View File

@@ -17,6 +17,11 @@ class ModelProviderService:
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
self.ap = ap self.ap = ap
@staticmethod
def _normalize_api_keys(api_key: str | None) -> list[str]:
normalized_api_key = api_key.strip() if api_key else ''
return [normalized_api_key] if normalized_api_key else []
async def get_providers(self) -> list[dict]: async def get_providers(self) -> list[dict]:
"""Get all providers""" """Get all providers"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider)) result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.ModelProvider))
@@ -177,7 +182,7 @@ class ModelProviderService:
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.ModelProvider) sqlalchemy.update(persistence_model.ModelProvider)
.where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000') .where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000')
.values(api_keys=[api_key]) .values(api_keys=self._normalize_api_keys(api_key))
) )
await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000') await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000')

View File

@@ -17,6 +17,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
"""OpenAI ChatCompletion API 请求器""" """OpenAI ChatCompletion API 请求器"""
client: openai.AsyncClient client: openai.AsyncClient
init_api_key: str = 'langbot-init-placeholder'
default_config: dict[str, typing.Any] = { default_config: dict[str, typing.Any] = {
'base_url': 'https://api.openai.com/v1', 'base_url': 'https://api.openai.com/v1',
@@ -25,7 +26,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
async def initialize(self): async def initialize(self):
self.client = openai.AsyncClient( self.client = openai.AsyncClient(
api_key='', api_key=self.init_api_key,
base_url=self.requester_cfg['base_url'].replace(' ', ''), base_url=self.requester_cfg['base_url'].replace(' ', ''),
timeout=self.requester_cfg['timeout'], timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']),

View File

@@ -11,10 +11,12 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.provider.session as provider_session
from langbot.pkg.api.http.service.model import _runtime_model_data from langbot.pkg.api.http.service.model import _runtime_model_data
from langbot.pkg.api.http.service.provider import ModelProviderService
from langbot.pkg.entity.persistence import model as persistence_model from langbot.pkg.entity.persistence import model as persistence_model
from langbot.pkg.pipeline.preproc.preproc import PreProcessor from langbot.pkg.pipeline.preproc.preproc import PreProcessor
from langbot.pkg.provider.modelmgr import requester from langbot.pkg.provider.modelmgr import requester
from langbot.pkg.provider.modelmgr.modelmgr import ModelManager from langbot.pkg.provider.modelmgr.modelmgr import ModelManager
from langbot.pkg.provider.modelmgr.requesters.chatcmpl import OpenAIChatCompletions
from langbot.pkg.provider.runners.localagent import LocalAgentRunner from langbot.pkg.provider.runners.localagent import LocalAgentRunner
@@ -58,6 +60,31 @@ def test_runtime_rerank_model_data_preserves_uuid_after_update_payload_uuid_remo
assert runtime_entity.name == 'rerank-model' assert runtime_entity.name == 'rerank-model'
def test_normalize_space_provider_api_keys_filters_blank_values():
assert ModelProviderService._normalize_api_keys('space-key') == ['space-key']
assert ModelProviderService._normalize_api_keys(' trimmed-key ') == ['trimmed-key']
assert ModelProviderService._normalize_api_keys('') == []
assert ModelProviderService._normalize_api_keys(' ') == []
assert ModelProviderService._normalize_api_keys(None) == []
@pytest.mark.asyncio
async def test_openai_requester_initialize_uses_placeholder_api_key(monkeypatch):
captured_kwargs = {}
def fake_client(**kwargs):
captured_kwargs.update(kwargs)
return SimpleNamespace(**kwargs)
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.chatcmpl.openai.AsyncClient', fake_client)
monkeypatch.setattr('langbot.pkg.provider.modelmgr.requesters.chatcmpl.httpx.AsyncClient', fake_client)
requester_inst = OpenAIChatCompletions(ap=SimpleNamespace(), config={})
await requester_inst.initialize()
assert captured_kwargs['api_key'] == OpenAIChatCompletions.init_api_key
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline(): async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline():
from langbot.pkg.api.http.service.model import LLMModelsService from langbot.pkg.api.http.service.model import LLMModelsService