From 4528000fc4059d2c88b88d523798ab8dd7d106c6 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Thu, 1 Jan 2026 02:00:24 +0800 Subject: [PATCH] refactor: model management --- .../pkg/api/http/controller/groups/user.py | 3 + src/langbot/pkg/api/http/service/model.py | 56 +++-- src/langbot/pkg/api/http/service/provider.py | 18 +- src/langbot/pkg/api/http/service/user.py | 3 + src/langbot/pkg/persistence/mgr.py | 18 +- src/langbot/pkg/plugin/handler.py | 2 +- src/langbot/pkg/provider/modelmgr/modelmgr.py | 192 ++++++++++-------- .../pkg/provider/modelmgr/requester.py | 43 ++-- .../modelmgr/requesters/anthropicmsgs.py | 4 +- .../modelmgr/requesters/bailianchatcmpl.py | 4 +- .../provider/modelmgr/requesters/chatcmpl.py | 6 +- .../modelmgr/requesters/deepseekchatcmpl.py | 2 +- .../modelmgr/requesters/geminichatcmpl.py | 2 +- .../modelmgr/requesters/jiekouaichatcmpl.py | 2 +- .../modelmgr/requesters/modelscopechatcmpl.py | 4 +- .../modelmgr/requesters/moonshotchatcmpl.py | 2 +- .../modelmgr/requesters/ppiochatcmpl.py | 2 +- .../pkg/provider/runners/localagent.py | 8 +- .../pkg/rag/knowledge/services/embedder.py | 2 +- .../pkg/rag/knowledge/services/retriever.py | 2 +- 20 files changed, 238 insertions(+), 137 deletions(-) diff --git a/src/langbot/pkg/api/http/controller/groups/user.py b/src/langbot/pkg/api/http/controller/groups/user.py index 086f9c28..b23a6c32 100644 --- a/src/langbot/pkg/api/http/controller/groups/user.py +++ b/src/langbot/pkg/api/http/controller/groups/user.py @@ -1,6 +1,7 @@ import quart import argon2 import asyncio +import traceback from .. import group @@ -141,8 +142,10 @@ class UserRouterGroup(group.RouterGroup): } ) except ValueError as e: + traceback.print_exc() return self.fail(1, str(e)) except Exception as e: + traceback.print_exc() return self.fail(2, f'OAuth callback failed: {str(e)}') @self.route('/info', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) diff --git a/src/langbot/pkg/api/http/service/model.py b/src/langbot/pkg/api/http/service/model.py index e12498cf..d6250ff6 100644 --- a/src/langbot/pkg/api/http/service/model.py +++ b/src/langbot/pkg/api/http/service/model.py @@ -85,10 +85,17 @@ class LLMModelsService: await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)) - llm_model = await self.get_llm_model(model_data['uuid']) - await self.ap.model_mgr.load_llm_model(llm_model) + runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid']) + if runtime_provider is None: + raise Exception('provider not found') - # Check if default pipeline has no model bound + runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider( + persistence_model.LLMModel(**model_data), + runtime_provider, + ) + self.ap.model_mgr.llm_models.append(runtime_llm_model) + + # set the default pipeline model to this model result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( persistence_pipeline.LegacyPipeline.is_default == True @@ -152,8 +159,16 @@ class LLMModelsService: ) await self.ap.model_mgr.remove_llm_model(model_uuid) - llm_model = await self.get_llm_model(model_uuid) - await self.ap.model_mgr.load_llm_model(llm_model) + + runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid']) + if runtime_provider is None: + raise Exception('provider not found') + + runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider( + persistence_model.LLMModel(**model_data), + runtime_provider, + ) + self.ap.model_mgr.llm_models.append(runtime_llm_model) async def delete_llm_model(self, model_uuid: str) -> None: """Delete an LLM model""" @@ -174,10 +189,10 @@ class LLMModelsService: if runtime_llm_model is None: raise Exception('model not found') else: - runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data) + runtime_llm_model = await self.ap.model_mgr.init_temporary_runtime_llm_model(model_data) extra_args = model_data.get('extra_args', {}) - await runtime_llm_model.requester.invoke_llm( + await runtime_llm_model.provider.requester.invoke_llm( query=None, model=runtime_llm_model, messages=[provider_message.Message(role='user', content='Hello, world! Please just reply a "Hello".')], @@ -244,8 +259,15 @@ class EmbeddingModelsService: sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) ) - embedding_model = await self.get_embedding_model(model_data['uuid']) - await self.ap.model_mgr.load_embedding_model(embedding_model) + runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid']) + if runtime_provider is None: + raise Exception('provider not found') + + runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider( + persistence_model.EmbeddingModel(**model_data), + runtime_provider, + ) + self.ap.model_mgr.embedding_models.append(runtime_embedding_model) return model_data['uuid'] @@ -298,8 +320,16 @@ class EmbeddingModelsService: ) await self.ap.model_mgr.remove_embedding_model(model_uuid) - embedding_model = await self.get_embedding_model(model_uuid) - await self.ap.model_mgr.load_embedding_model(embedding_model) + + runtime_provider = self.ap.model_mgr.provider_dict.get(model_data['provider_uuid']) + if runtime_provider is None: + raise Exception('provider not found') + + runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider( + persistence_model.EmbeddingModel(**model_data), + runtime_provider, + ) + self.ap.model_mgr.embedding_models.append(runtime_embedding_model) async def delete_embedding_model(self, model_uuid: str) -> None: """Delete an embedding model""" @@ -322,9 +352,9 @@ class EmbeddingModelsService: if runtime_embedding_model is None: raise Exception('model not found') else: - runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data) + runtime_embedding_model = await self.ap.model_mgr.init_temporary_runtime_embedding_model(model_data) - await runtime_embedding_model.requester.invoke_embedding( + await runtime_embedding_model.provider.requester.invoke_embedding( model=runtime_embedding_model, input_text=['Hello, world!'], extra_args={}, diff --git a/src/langbot/pkg/api/http/service/provider.py b/src/langbot/pkg/api/http/service/provider.py index eb99c092..1abb6e9f 100644 --- a/src/langbot/pkg/api/http/service/provider.py +++ b/src/langbot/pkg/api/http/service/provider.py @@ -61,6 +61,10 @@ class ModelProviderService: await self.ap.persistence_mgr.execute_async( sqlalchemy.insert(persistence_model.ModelProvider).values(**provider_data) ) + + # load to runtime + runtime_provider = await self.ap.model_mgr.load_provider(provider_data) + self.ap.model_mgr.provider_dict[runtime_provider.provider_entity.uuid] = runtime_provider return provider_data['uuid'] async def update_provider(self, provider_uuid: str, provider_data: dict) -> None: @@ -72,8 +76,7 @@ class ModelProviderService: .where(persistence_model.ModelProvider.uuid == provider_uuid) .values(**provider_data) ) - # Reload all models using this provider - await self.ap.model_mgr.load_models_from_db() + await self.ap.model_mgr.reload_provider(provider_uuid) async def delete_provider(self, provider_uuid: str) -> None: """Delete a provider (only if no models reference it)""" @@ -100,6 +103,8 @@ class ModelProviderService: ) ) + await self.ap.model_mgr.remove_provider(provider_uuid) + async def get_provider_model_counts(self, provider_uuid: str) -> dict: """Get count of models using this provider""" llm_result = await self.ap.persistence_mgr.execute_async( @@ -150,3 +155,12 @@ class ModelProviderService: 'api_keys': api_keys or [], } ) + + async def update_space_model_provider_api_keys(self, api_key: str) -> None: + """Update Space model provider API keys""" + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.ModelProvider) + .where(persistence_model.ModelProvider.uuid == '00000000-0000-0000-0000-000000000000') + .values(api_keys=[api_key]) + ) + await self.ap.model_mgr.reload_provider('00000000-0000-0000-0000-000000000000') diff --git a/src/langbot/pkg/api/http/service/user.py b/src/langbot/pkg/api/http/service/user.py index 2b37d794..d83ee0ff 100644 --- a/src/langbot/pkg/api/http/service/user.py +++ b/src/langbot/pkg/api/http/service/user.py @@ -149,6 +149,7 @@ class UserService: space_access_token_expires_at=expires_at, ) ) + await self.ap.provider_service.update_space_model_provider_api_keys(api_key) return await self.get_user_by_space_account_uuid(space_account_uuid) # Check if user with same email exists @@ -167,6 +168,7 @@ class UserService: space_access_token_expires_at=expires_at, ) ) + await self.ap.provider_service.update_space_model_provider_api_keys(api_key) return await self.get_user_by_email(email) # Check if system is already initialized @@ -189,6 +191,7 @@ class UserService: space_access_token_expires_at=expires_at, ) ) + await self.ap.provider_service.update_space_model_provider_api_keys(api_key) return await self.get_user_by_space_account_uuid(space_account_uuid) diff --git a/src/langbot/pkg/persistence/mgr.py b/src/langbot/pkg/persistence/mgr.py index 1311ff4f..8e147799 100644 --- a/src/langbot/pkg/persistence/mgr.py +++ b/src/langbot/pkg/persistence/mgr.py @@ -125,25 +125,39 @@ class PersistenceManager: await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data)) async def write_space_model_providers(self): + space_models_gateway_api_url = self.ap.instance_config.data.get('space', {}).get( + 'models_gateway_api_url', 'https://api.langbot.cloud/v1' + ) + # write space model providers result = await self.execute_async( sqlalchemy.select(persistence_model.ModelProvider).where( persistence_model.ModelProvider.requester == 'space-chat-completions' ) ) - if result.first() is None: + exists_space_chat_completions_model_provider = result.first() + + # api keys will be set/updated when the oauth callback + if exists_space_chat_completions_model_provider is None: self.ap.logger.info('Creating space model providers...') space_chat_completions_model_provider = { 'uuid': '00000000-0000-0000-0000-000000000000', 'name': 'LangBot Models', 'requester': 'space-chat-completions', - 'base_url': 'https://api.langbot.cloud/v1', + 'base_url': space_models_gateway_api_url, 'api_keys': [], } await self.execute_async( sqlalchemy.insert(persistence_model.ModelProvider).values(space_chat_completions_model_provider) ) + else: + if exists_space_chat_completions_model_provider.base_url != space_models_gateway_api_url: + await self.execute_async( + sqlalchemy.update(persistence_model.ModelProvider) + .where(persistence_model.ModelProvider.uuid == exists_space_chat_completions_model_provider.uuid) + .values({'base_url': space_models_gateway_api_url}) + ) # ================================= diff --git a/src/langbot/pkg/plugin/handler.py b/src/langbot/pkg/plugin/handler.py index 35eeec50..e9819e0d 100644 --- a/src/langbot/pkg/plugin/handler.py +++ b/src/langbot/pkg/plugin/handler.py @@ -324,7 +324,7 @@ class RuntimeConnectionHandler(handler.Handler): messages_obj = [provider_message.Message.model_validate(message) for message in messages] funcs_obj = [resource_tool.LLMTool.model_validate(func) for func in funcs] - result = await llm_model.requester.invoke_llm( + result = await llm_model.provider.requester.invoke_llm( query=None, model=llm_model, messages=messages_obj, diff --git a/src/langbot/pkg/provider/modelmgr/modelmgr.py b/src/langbot/pkg/provider/modelmgr/modelmgr.py index 0c84eaad..de72f1bc 100644 --- a/src/langbot/pkg/provider/modelmgr/modelmgr.py +++ b/src/langbot/pkg/provider/modelmgr/modelmgr.py @@ -16,6 +16,9 @@ class ModelManager: ap: app.Application + provider_dict: dict[str, requester.RuntimeProvider] + """运行时模型提供商字典, uuid -> RuntimeProvider""" + llm_models: list[requester.RuntimeLLMModel] embedding_models: list[requester.RuntimeEmbeddingModel] @@ -51,23 +54,31 @@ class ModelManager: self.embedding_models = [] # Load all providers first + self.provider_dict = {} providers_result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_model.ModelProvider) ) - providers = {p.uuid: p for p in providers_result.all()} + for provider in providers_result.all(): + try: + runtime_provider = await self.load_provider(provider) + self.provider_dict[provider.uuid] = runtime_provider + except provider_errors.RequesterNotFoundError as e: + self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping provider {provider.uuid}') + continue + except Exception as e: + self.ap.logger.error(f'Failed to load provider {provider.uuid}: {e}\n{traceback.format_exc()}') # Load LLM models result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) llm_models = result.all() for llm_model in llm_models: try: - provider = providers.get(llm_model.provider_uuid) + provider = self.provider_dict.get(llm_model.provider_uuid) if provider is None: self.ap.logger.warning(f'Provider {llm_model.provider_uuid} not found for model {llm_model.uuid}') continue - await self.load_llm_model_with_provider(llm_model, provider) - except provider_errors.RequesterNotFoundError as e: - self.ap.logger.warning(f'Requester {e.requester_name} not found, skipping llm model {llm_model.uuid}') + runtime_llm_model = await self.load_llm_model_with_provider(llm_model, provider) + self.llm_models.append(runtime_llm_model) except Exception as e: self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}') @@ -76,17 +87,14 @@ class ModelManager: embedding_models = result.all() for embedding_model in embedding_models: try: - provider = providers.get(embedding_model.provider_uuid) + provider = self.provider_dict.get(embedding_model.provider_uuid) if provider is None: self.ap.logger.warning( f'Provider {embedding_model.provider_uuid} not found for model {embedding_model.uuid}' ) continue - await self.load_embedding_model_with_provider(embedding_model, provider) - except provider_errors.RequesterNotFoundError as e: - self.ap.logger.warning( - f'Requester {e.requester_name} not found, skipping embedding model {embedding_model.uuid}' - ) + runtime_embedding_model = await self.load_embedding_model_with_provider(embedding_model, provider) + self.embedding_models.append(runtime_embedding_model) except Exception as e: self.ap.logger.error(f'Failed to load model {embedding_model.uuid}: {e}\n{traceback.format_exc()}') @@ -149,123 +157,139 @@ class ModelManager: preserve_uuid=True, ) - async def init_runtime_llm_model( + async def init_temporary_runtime_llm_model( self, model_info: dict, - ): + ) -> requester.RuntimeLLMModel: """Initialize runtime LLM model from dict (for testing)""" provider_info = model_info.get('provider', {}) - requester_name = provider_info.get('requester', '') - base_url = provider_info.get('base_url', '') - api_keys = provider_info.get('api_keys', []) - if requester_name not in self.requester_dict: - raise provider_errors.RequesterNotFoundError(requester_name) - - requester_cfg = {'base_url': base_url} - requester_inst = self.requester_dict[requester_name](ap=self.ap, config=requester_cfg) - await requester_inst.initialize() - - # Create a temporary model entity - model_entity = persistence_model.LLMModel( - uuid=model_info.get('uuid', ''), - name=model_info.get('name', ''), - provider_uuid='', - abilities=model_info.get('abilities', []), - extra_args=model_info.get('extra_args', {}), - ) + runtime_provider = await self.load_provider(provider_info) runtime_llm_model = requester.RuntimeLLMModel( - model_entity=model_entity, - token_mgr=token.TokenManager(name=model_entity.uuid, tokens=api_keys), - requester=requester_inst, + model_entity=persistence_model.LLMModel( + uuid=model_info.get('uuid', ''), + name=model_info.get('name', ''), + provider_uuid='', + abilities=model_info.get('abilities', []), + extra_args=model_info.get('extra_args', {}), + ), + provider=runtime_provider, ) return runtime_llm_model - async def init_runtime_embedding_model( + async def init_temporary_runtime_embedding_model( self, model_info: dict, - ): + ) -> requester.RuntimeEmbeddingModel: """Initialize runtime embedding model from dict (for testing)""" provider_info = model_info.get('provider', {}) - requester_name = provider_info.get('requester', '') - base_url = provider_info.get('base_url', '') - api_keys = provider_info.get('api_keys', []) - - if requester_name not in self.requester_dict: - raise provider_errors.RequesterNotFoundError(requester_name) - - requester_cfg = {'base_url': base_url} - requester_inst = self.requester_dict[requester_name](ap=self.ap, config=requester_cfg) - await requester_inst.initialize() - - model_entity = persistence_model.EmbeddingModel( - uuid=model_info.get('uuid', ''), - name=model_info.get('name', ''), - provider_uuid='', - extra_args=model_info.get('extra_args', {}), - ) + runtime_provider = await self.load_provider(provider_info) runtime_embedding_model = requester.RuntimeEmbeddingModel( - model_entity=model_entity, - token_mgr=token.TokenManager(name=model_entity.uuid, tokens=api_keys), - requester=requester_inst, + model_entity=persistence_model.EmbeddingModel( + uuid=model_info.get('uuid', ''), + name=model_info.get('name', ''), + provider_uuid='', + extra_args=model_info.get('extra_args', {}), + ), + provider=runtime_provider, ) return runtime_embedding_model + async def load_provider( + self, provider_info: persistence_model.ModelProvider | sqlalchemy.Row | dict + ) -> requester.RuntimeProvider: + """Load provider from dict""" + if isinstance(provider_info, sqlalchemy.Row): + provider_entity = persistence_model.ModelProvider(**provider_info._mapping) + elif isinstance(provider_info, dict): + provider_entity = persistence_model.ModelProvider(**provider_info) + else: + provider_entity = provider_info + + if provider_entity.requester not in self.requester_dict: + raise provider_errors.RequesterNotFoundError(provider_entity.requester) + + requester_inst = self.requester_dict[provider_entity.requester]( + ap=self.ap, config={'base_url': provider_entity.base_url} + ) + await requester_inst.initialize() + + token_mgr = token.TokenManager(name=provider_entity.uuid, tokens=provider_entity.api_keys or []) + + provider = requester.RuntimeProvider( + provider_entity=provider_entity, + token_mgr=token_mgr, + requester=requester_inst, + ) + return provider + + async def remove_provider(self, provider_uuid: str): + """Remove provider + + This method will not consider the models using this provider, + because the models should be removed by the caller. + """ + del self.provider_dict[provider_uuid] + + async def reload_provider(self, provider_uuid: str): + """Reload provider""" + provider_entity = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.ModelProvider).where( + persistence_model.ModelProvider.uuid == provider_uuid + ) + ) + provider_entity = provider_entity.first() + if provider_entity is None: + raise provider_errors.ProviderNotFoundError(provider_uuid) + + new_runtime_provider = await self.load_provider(provider_entity) + + # update refs in runtime models + for model in self.llm_models: + if model.provider.provider_entity.uuid == provider_uuid: + model.provider = new_runtime_provider + for model in self.embedding_models: + if model.provider.provider_entity.uuid == provider_uuid: + model.provider = new_runtime_provider + + # update ref in provider dict + self.provider_dict[provider_uuid] = new_runtime_provider + async def load_llm_model_with_provider( self, model_info: persistence_model.LLMModel | sqlalchemy.Row, - provider: persistence_model.ModelProvider | sqlalchemy.Row, - ): + provider: requester.RuntimeProvider, + ) -> requester.RuntimeLLMModel: """Load LLM model with provider info""" if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.LLMModel(**model_info._mapping) - if isinstance(provider, sqlalchemy.Row): - provider = persistence_model.ModelProvider(**provider._mapping) - - if provider.requester not in self.requester_dict: - raise provider_errors.RequesterNotFoundError(provider.requester) - - requester_cfg = {'base_url': provider.base_url} - requester_inst = self.requester_dict[provider.requester](ap=self.ap, config=requester_cfg) - await requester_inst.initialize() runtime_llm_model = requester.RuntimeLLMModel( model_entity=model_info, - token_mgr=token.TokenManager(name=model_info.uuid, tokens=provider.api_keys or []), - requester=requester_inst, + provider=provider, ) - self.llm_models.append(runtime_llm_model) + return runtime_llm_model async def load_embedding_model_with_provider( self, model_info: persistence_model.EmbeddingModel | sqlalchemy.Row, - provider: persistence_model.ModelProvider | sqlalchemy.Row, - ): + provider: requester.RuntimeProvider, + ) -> requester.RuntimeEmbeddingModel: """Load embedding model with provider info""" if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.EmbeddingModel(**model_info._mapping) - if isinstance(provider, sqlalchemy.Row): - provider = persistence_model.ModelProvider(**provider._mapping) - - if provider.requester not in self.requester_dict: - raise provider_errors.RequesterNotFoundError(provider.requester) - - requester_cfg = {'base_url': provider.base_url} - requester_inst = self.requester_dict[provider.requester](ap=self.ap, config=requester_cfg) - await requester_inst.initialize() runtime_embedding_model = requester.RuntimeEmbeddingModel( model_entity=model_info, - token_mgr=token.TokenManager(name=model_info.uuid, tokens=provider.api_keys or []), - requester=requester_inst, + provider=provider, ) - self.embedding_models.append(runtime_embedding_model) + return runtime_embedding_model async def load_llm_model(self, model_info: dict): """Load LLM model from dict (with provider info)""" diff --git a/src/langbot/pkg/provider/modelmgr/requester.py b/src/langbot/pkg/provider/modelmgr/requester.py index 52d73eea..3052a62f 100644 --- a/src/langbot/pkg/provider/modelmgr/requester.py +++ b/src/langbot/pkg/provider/modelmgr/requester.py @@ -11,11 +11,11 @@ import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query import langbot_plugin.api.entities.builtin.provider.message as provider_message -class RuntimeLLMModel: - """运行时模型""" +class RuntimeProvider: + """运行时模型提供商""" - model_entity: persistence_model.LLMModel - """模型数据""" + provider_entity: persistence_model.ModelProvider + """提供商数据""" token_mgr: token.TokenManager """api key管理器""" @@ -25,36 +25,49 @@ class RuntimeLLMModel: def __init__( self, - model_entity: persistence_model.LLMModel, + provider_entity: persistence_model.ModelProvider, token_mgr: token.TokenManager, requester: ProviderAPIRequester, ): - self.model_entity = model_entity + self.provider_entity = provider_entity self.token_mgr = token_mgr self.requester = requester +class RuntimeLLMModel: + """运行时模型""" + + model_entity: persistence_model.LLMModel + """模型数据""" + + provider: RuntimeProvider + """提供商实例""" + + def __init__( + self, + model_entity: persistence_model.LLMModel, + provider: RuntimeProvider, + ): + self.model_entity = model_entity + self.provider = provider + + class RuntimeEmbeddingModel: """运行时 Embedding 模型""" model_entity: persistence_model.EmbeddingModel """模型数据""" - token_mgr: token.TokenManager - """api key管理器""" - - requester: ProviderAPIRequester - """请求器实例""" + provider: RuntimeProvider + """提供商实例""" def __init__( self, model_entity: persistence_model.EmbeddingModel, - token_mgr: token.TokenManager, - requester: ProviderAPIRequester, + provider: RuntimeProvider, ): self.model_entity = model_entity - self.token_mgr = token_mgr - self.requester = requester + self.provider = provider class ProviderAPIRequester(metaclass=abc.ABCMeta): diff --git a/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 3a1b9384..9394b73d 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -56,7 +56,7 @@ class AnthropicMessages(requester.ProviderAPIRequester): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message: - self.client.api_key = model.token_mgr.get_token() + self.client.api_key = model.provider.token_mgr.get_token() args = extra_args.copy() args['model'] = model.model_entity.name @@ -190,7 +190,7 @@ class AnthropicMessages(requester.ProviderAPIRequester): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message: - self.client.api_key = model.token_mgr.get_token() + self.client.api_key = model.provider.token_mgr.get_token() args = extra_args.copy() args['model'] = model.model_entity.name diff --git a/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.py index c60165bb..9da6e1b4 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/bailianchatcmpl.py @@ -30,7 +30,7 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name @@ -117,7 +117,7 @@ class BailianChatCompletions(modelscopechatcmpl.ModelScopeChatCompletions): if is_use_dashscope_call: response = dashscope.MultiModalConversation.call( # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx" - api_key=use_model.token_mgr.get_token(), + api_key=use_model.provider.token_mgr.get_token(), model=use_model.model_entity.name, messages=messages, result_format='message', diff --git a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py index b940859e..beb45936 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -130,7 +130,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.MessageChunk: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name @@ -251,7 +251,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name @@ -337,7 +337,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): extra_args: dict[str, typing.Any] = {}, ) -> list[list[float]]: """调用 Embedding API""" - self.client.api_key = model.token_mgr.get_token() + self.client.api_key = model.provider.token_mgr.get_token() args = { 'model': model.model_entity.name, diff --git a/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index 83b2bfa4..a95371da 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -26,7 +26,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name diff --git a/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py index 9741e6b3..f934145e 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/geminichatcmpl.py @@ -29,7 +29,7 @@ class GeminiChatCompletions(chatcmpl.OpenAIChatCompletions): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.MessageChunk: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name diff --git a/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.py index 60001037..305ae21f 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/jiekouaichatcmpl.py @@ -109,7 +109,7 @@ class JieKouAIChatCompletions(chatcmpl.OpenAIChatCompletions): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name diff --git a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index 8684a677..0d92bb51 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -131,7 +131,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name @@ -181,7 +181,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name diff --git a/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py index aa3d0f4f..969392b4 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py @@ -27,7 +27,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name diff --git a/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.py b/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.py index 9658312b..1836bd62 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/ppiochatcmpl.py @@ -109,7 +109,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message | typing.AsyncGenerator[provider_message.MessageChunk, None]: - self.client.api_key = use_model.token_mgr.get_token() + self.client.api_key = use_model.provider.token_mgr.get_token() args = {} args['model'] = use_model.model_entity.name diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index 6c29415e..4197f076 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -130,7 +130,7 @@ class LocalAgentRunner(runner.RequestRunner): if not is_stream: # 非流式输出,直接请求 - msg = await use_llm_model.requester.invoke_llm( + msg = await use_llm_model.provider.requester.invoke_llm( query, use_llm_model, req_messages, @@ -147,7 +147,7 @@ class LocalAgentRunner(runner.RequestRunner): accumulated_content = '' # 从开始累积的所有内容 last_role = 'assistant' msg_sequence = 1 - async for msg in use_llm_model.requester.invoke_llm_stream( + async for msg in use_llm_model.provider.requester.invoke_llm_stream( query, use_llm_model, req_messages, @@ -250,7 +250,7 @@ class LocalAgentRunner(runner.RequestRunner): last_role = 'assistant' msg_sequence = first_end_sequence - async for msg in use_llm_model.requester.invoke_llm_stream( + async for msg in use_llm_model.provider.requester.invoke_llm_stream( query, use_llm_model, req_messages, @@ -306,7 +306,7 @@ class LocalAgentRunner(runner.RequestRunner): ) else: # 处理完所有调用,再次请求 - msg = await use_llm_model.requester.invoke_llm( + msg = await use_llm_model.provider.requester.invoke_llm( query, use_llm_model, req_messages, diff --git a/src/langbot/pkg/rag/knowledge/services/embedder.py b/src/langbot/pkg/rag/knowledge/services/embedder.py index c8a1c3d3..a067c90c 100644 --- a/src/langbot/pkg/rag/knowledge/services/embedder.py +++ b/src/langbot/pkg/rag/knowledge/services/embedder.py @@ -33,7 +33,7 @@ class Embedder(BaseService): await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts)) # get embeddings - embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding( + embeddings_list: list[list[float]] = await embedding_model.provider.requester.invoke_embedding( model=embedding_model, input_text=chunks, extra_args={}, # TODO: add extra args diff --git a/src/langbot/pkg/rag/knowledge/services/retriever.py b/src/langbot/pkg/rag/knowledge/services/retriever.py index dada8d5f..f2f4bef5 100644 --- a/src/langbot/pkg/rag/knowledge/services/retriever.py +++ b/src/langbot/pkg/rag/knowledge/services/retriever.py @@ -19,7 +19,7 @@ class Retriever(base_service.BaseService): f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}" ) - query_embedding: list[float] = await embedding_model.requester.invoke_embedding( + query_embedding: list[float] = await embedding_model.provider.requester.invoke_embedding( model=embedding_model, input_text=[query], extra_args={}, # TODO: add extra args