diff --git a/src/langbot/pkg/api/http/service/model.py b/src/langbot/pkg/api/http/service/model.py index b670e99d..320104d8 100644 --- a/src/langbot/pkg/api/http/service/model.py +++ b/src/langbot/pkg/api/http/service/model.py @@ -23,6 +23,17 @@ def _parse_provider_api_keys(provider_dict: dict) -> dict: return provider_dict +def _runtime_model_data(model_uuid: str, model_data: dict) -> dict: + """Return model data for rebuilding runtime models after an update. + + Update payloads intentionally omit uuid before writing to the database. + Runtime model entities still need the stable uuid so pipeline configs can + resolve the in-memory model immediately after an edit, without requiring a + process restart. + """ + return {**model_data, 'uuid': model_uuid} + + class LLMModelsService: ap: app.Application @@ -173,7 +184,7 @@ class LLMModelsService: raise Exception('provider not found') runtime_llm_model = await self.ap.model_mgr.load_llm_model_with_provider( - persistence_model.LLMModel(**model_data), + persistence_model.LLMModel(**_runtime_model_data(model_uuid, model_data)), runtime_provider, ) self.ap.model_mgr.llm_models.append(runtime_llm_model) @@ -334,7 +345,7 @@ class EmbeddingModelsService: raise Exception('provider not found') runtime_embedding_model = await self.ap.model_mgr.load_embedding_model_with_provider( - persistence_model.EmbeddingModel(**model_data), + persistence_model.EmbeddingModel(**_runtime_model_data(model_uuid, model_data)), runtime_provider, ) self.ap.model_mgr.embedding_models.append(runtime_embedding_model) @@ -492,7 +503,7 @@ class RerankModelsService: raise Exception('provider not found') runtime_rerank_model = await self.ap.model_mgr.load_rerank_model_with_provider( - persistence_model.RerankModel(**model_data), + persistence_model.RerankModel(**_runtime_model_data(model_uuid, model_data)), runtime_provider, ) self.ap.model_mgr.rerank_models.append(runtime_rerank_model) diff --git a/tests/unit_tests/provider/__init__.py b/tests/unit_tests/provider/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/unit_tests/provider/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit_tests/provider/test_model_service.py b/tests/unit_tests/provider/test_model_service.py new file mode 100644 index 00000000..8fac8278 --- /dev/null +++ b/tests/unit_tests/provider/test_model_service.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import pytest +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.platform.entities as platform_entities +import langbot_plugin.api.entities.builtin.platform.events as platform_events +import langbot_plugin.api.entities.builtin.platform.message as platform_message +import langbot_plugin.api.entities.builtin.provider.session as provider_session + +from langbot.pkg.api.http.service.model import _runtime_model_data +from langbot.pkg.entity.persistence import model as persistence_model +from langbot.pkg.pipeline.preproc.preproc import PreProcessor +from langbot.pkg.provider.modelmgr import requester +from langbot.pkg.provider.modelmgr.modelmgr import ModelManager +from langbot.pkg.provider.runners.localagent import LocalAgentRunner + + +def test_runtime_llm_model_data_preserves_uuid_after_update_payload_uuid_removed(): + update_payload = { + 'name': 'Qwen3.5-27B', + 'provider_uuid': 'provider-uuid', + 'abilities': [], + 'extra_args': {}, + } + + runtime_entity = persistence_model.LLMModel(**_runtime_model_data('model-uuid', update_payload)) + + assert runtime_entity.uuid == 'model-uuid' + assert runtime_entity.name == 'Qwen3.5-27B' + + +def test_runtime_embedding_model_data_preserves_uuid_after_update_payload_uuid_removed(): + update_payload = { + 'name': 'embedding-model', + 'provider_uuid': 'provider-uuid', + 'extra_args': {}, + } + + runtime_entity = persistence_model.EmbeddingModel(**_runtime_model_data('embedding-uuid', update_payload)) + + assert runtime_entity.uuid == 'embedding-uuid' + assert runtime_entity.name == 'embedding-model' + + +def test_runtime_rerank_model_data_preserves_uuid_after_update_payload_uuid_removed(): + update_payload = { + 'name': 'rerank-model', + 'provider_uuid': 'provider-uuid', + 'extra_args': {}, + } + + runtime_entity = persistence_model.RerankModel(**_runtime_model_data('rerank-uuid', update_payload)) + + assert runtime_entity.uuid == 'rerank-uuid' + assert runtime_entity.name == 'rerank-model' + + +@pytest.mark.asyncio +async def test_updated_llm_model_is_immediately_usable_by_local_agent_pipeline(): + from langbot.pkg.api.http.service.model import LLMModelsService + + model_uuid = 'qwen-model-uuid' + provider_uuid = 'ollama-provider-uuid' + + ap = SimpleNamespace() + ap.logger = Mock() + ap.persistence_mgr = SimpleNamespace(execute_async=AsyncMock()) + ap.tool_mgr = SimpleNamespace(get_all_tools=AsyncMock(return_value=[])) + ap.plugin_connector = SimpleNamespace( + emit_event=AsyncMock(return_value=SimpleNamespace(event=SimpleNamespace(default_prompt=[], prompt=[]))) + ) + + ap.model_mgr = ModelManager(ap) + runtime_provider = Mock() + ap.model_mgr.provider_dict = {provider_uuid: runtime_provider} + ap.model_mgr.llm_models = [ + requester.RuntimeLLMModel( + model_entity=persistence_model.LLMModel( + uuid=model_uuid, + name='old-qwen-name', + provider_uuid=provider_uuid, + abilities=[], + extra_args={}, + ), + provider=runtime_provider, + ) + ] + + await LLMModelsService(ap).update_llm_model( + model_uuid, + { + 'name': 'Qwen3.5-27B', + 'provider_uuid': provider_uuid, + 'abilities': [], + 'extra_args': {}, + }, + ) + + runtime_model = await ap.model_mgr.get_model_by_uuid(model_uuid) + assert runtime_model.model_entity.uuid == model_uuid + assert runtime_model.model_entity.name == 'Qwen3.5-27B' + + session = SimpleNamespace( + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=12345, + ) + conversation = SimpleNamespace( + uuid='conversation-uuid', + create_time=None, + update_time=None, + prompt=SimpleNamespace(messages=[], copy=Mock(return_value=SimpleNamespace(messages=[]))), + messages=[], + ) + ap.sess_mgr = SimpleNamespace( + get_session=AsyncMock(return_value=session), + get_conversation=AsyncMock(return_value=conversation), + ) + + message_chain = platform_message.MessageChain([platform_message.Plain(text='hello')]) + sender = platform_entities.Friend(id=12345, nickname='Tester', remark=None) + message_event = platform_events.FriendMessage( + type='FriendMessage', + sender=sender, + message_chain=message_chain, + time=1710000000, + ) + pipeline_config = { + 'ai': { + 'runner': {'runner': 'local-agent'}, + 'local-agent': { + 'model': {'primary': model_uuid, 'fallbacks': []}, + 'prompt': [], + 'knowledge-bases': [], + }, + }, + 'trigger': {'misc': {'combine-quote-message': False}}, + 'output': {'misc': {'remove-think': False}}, + } + query = pipeline_query.Query.model_construct( + query_id='query-id', + launcher_type=provider_session.LauncherTypes.PERSON, + launcher_id=12345, + sender_id=12345, + message_chain=message_chain, + message_event=message_event, + adapter=AsyncMock(), + pipeline_uuid='pipeline-uuid', + bot_uuid='bot-uuid', + pipeline_config=pipeline_config, + session=None, + prompt=None, + messages=[], + user_message=None, + use_funcs=[], + use_llm_model_uuid=None, + variables={}, + resp_messages=[], + resp_message_chain=None, + current_stage_name=None, + ) + + result = await PreProcessor(ap).process(query, 'PreProcessor') + processed_query = result.new_query + + assert processed_query.use_llm_model_uuid == model_uuid + + runner = SimpleNamespace(ap=ap, pipeline_config=pipeline_config) + candidates = await LocalAgentRunner._get_model_candidates(runner, processed_query) + + assert [model.model_entity.uuid for model in candidates] == [model_uuid]