mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
476 lines
14 KiB
Python
476 lines
14 KiB
Python
"""
|
|
Message and query factories for tests.
|
|
|
|
Provides reusable factories for creating message chains, events, and query objects.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, Mock
|
|
import typing
|
|
|
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
|
|
|
|
|
# Module-wide counter backing _next_query_id(); bumped once per generated query.
_query_counter: int = 0

# Runner identifier installed as ``ai.runner.id`` (and as the ``ai.runner_config``
# key) in the default pipeline config built by _base_query().
DEFAULT_RUNNER_ID: str = "plugin:langbot/local-agent/default"
|
|
|
|
|
|
def _next_query_id() -> int:
    """Advance the module-wide counter and return its new value.

    With the counter starting at 0, the first call yields 1 and every later
    call yields the previous value plus one.
    """
    global _query_counter
    fresh_id = _query_counter + 1
    _query_counter = fresh_id
    return fresh_id
|
|
|
|
|
|
# ============== Message Chain Factories ==============
|
|
|
|
|
|
def text_chain(text: str = "hello") -> platform_message.MessageChain:
    """Build a message chain holding exactly one plain-text segment.

    Args:
        text: Content of the single ``Plain`` segment.
    """
    plain_segment = platform_message.Plain(text=text)
    return platform_message.MessageChain([plain_segment])
|
|
|
|
|
|
def group_text_chain(text: str = "hello") -> platform_message.MessageChain:
    """Build a plain-text chain for group-chat tests.

    The chain is identical to the one produced by ``text_chain``; the group
    context is supplied by the event the chain is attached to, not by the chain.
    """
    return text_chain(text=text)
|
|
|
|
|
|
def mention_chain(
    text: str = "hello",
    target: typing.Union[int, str] = 12345,
) -> platform_message.MessageChain:
    """Build a chain that @mentions ``target`` and then carries ``text``.

    Args:
        text: Message body; it is prefixed with a single space so that it
            reads naturally after the mention.
        target: ID placed in the ``At`` segment.
    """
    mention_segment = platform_message.At(target=target)
    body_segment = platform_message.Plain(text=f" {text}")
    return platform_message.MessageChain([mention_segment, body_segment])
|
|
|
|
|
|
def image_chain(
    text: str = "",
    url: str = "https://example.com/image.png",
) -> platform_message.MessageChain:
    """Build a chain containing an image, optionally preceded by text.

    Args:
        text: Optional caption; a ``Plain`` segment is emitted only when it
            is non-empty.
        url: URL stored on the trailing ``Image`` segment.
    """
    caption_segments = [platform_message.Plain(text=text)] if text else []
    return platform_message.MessageChain(
        [*caption_segments, platform_message.Image(url=url)]
    )
|
|
|
|
|
|
def command_chain(
    command: str = "help",
    prefix: str = "/",
) -> platform_message.MessageChain:
    """Build a one-segment chain whose text is ``prefix`` followed by ``command``.

    With the defaults this yields the text ``/help``.
    """
    command_text = f"{prefix}{command}"
    return platform_message.MessageChain(
        [platform_message.Plain(text=command_text)]
    )
|
|
|
|
|
|
# ============== Message Event Factories ==============
|
|
|
|
|
|
def friend_message_event(
    message_chain: platform_message.MessageChain,
    sender_id: typing.Union[int, str] = 12345,
    nickname: str = "TestUser",
) -> platform_events.FriendMessage:
    """Wrap ``message_chain`` in a private (friend) message event.

    The sender is a ``Friend`` with no remark. The event time is fixed at
    1609459200 (2021-01-01T00:00:00Z) so every generated event is identical
    apart from its arguments.

    Args:
        message_chain: Chain carried by the event.
        sender_id: ID of the sending friend.
        nickname: Display nickname of the sending friend.
    """
    return platform_events.FriendMessage(
        type="FriendMessage",
        sender=platform_entities.Friend(
            id=sender_id,
            nickname=nickname,
            remark=None,
        ),
        message_chain=message_chain,
        time=1609459200,
    )
|
|
|
|
|
|
def group_message_event(
    message_chain: platform_message.MessageChain,
    sender_id: typing.Union[int, str] = 12345,
    sender_name: str = "TestUser",
    group_id: typing.Union[int, str] = 99999,
    group_name: str = "TestGroup",
) -> platform_events.GroupMessage:
    """Wrap ``message_chain`` in a group message event.

    Both the group and the sending member are given ``Permission.Member``.
    The event time is fixed at 1609459200 (2021-01-01T00:00:00Z).

    Args:
        message_chain: Chain carried by the event.
        sender_id: ID of the group member who sent the message.
        sender_name: Member name of the sender.
        group_id: ID of the group the message was posted in.
        group_name: Display name of that group.
    """
    member_permission = platform_entities.Permission.Member
    owning_group = platform_entities.Group(
        id=group_id,
        name=group_name,
        permission=member_permission,
    )
    posting_member = platform_entities.GroupMember(
        id=sender_id,
        member_name=sender_name,
        permission=member_permission,
        group=owning_group,
    )
    return platform_events.GroupMessage(
        type="GroupMessage",
        sender=posting_member,
        message_chain=message_chain,
        time=1609459200,
    )
|
|
|
|
|
|
# ============== Mock Adapter Factory ==============
|
|
|
|
|
|
def mock_adapter() -> Mock:
    """Return a fresh ``AsyncMock`` standing in for a platform adapter.

    ``is_stream_output_supported`` is an awaitable that resolves to ``False``.
    ``reply_message`` and ``reply_message_chunk`` are independent ``AsyncMock``
    instances, so tests can assert on their awaits.
    """
    adapter = AsyncMock()
    # Report streaming as unsupported unless a test overrides this attribute.
    adapter.is_stream_output_supported = AsyncMock(return_value=False)
    for reply_method in ("reply_message", "reply_message_chunk"):
        setattr(adapter, reply_method, AsyncMock())
    return adapter
|
|
|
|
|
|
# ============== Query Factories ==============
|
|
|
|
|
|
def _default_pipeline_config() -> dict:
    """Return a freshly built default pipeline config that targets DEFAULT_RUNNER_ID."""
    runner_settings = {
        "model": {"primary": "test-model-uuid", "fallbacks": []},
        "prompt": [{"role": "system", "content": "test-prompt"}],
    }
    return {
        "ai": {
            "runner": {"id": DEFAULT_RUNNER_ID},
            "runner_config": {DEFAULT_RUNNER_ID: runner_settings},
        },
        "output": {"misc": {"at-sender": False, "quote-origin": False}},
        "trigger": {"misc": {"combine-quote-message": False}},
    }


def _base_query(
    message_chain: platform_message.MessageChain,
    message_event: platform_events.MessageEvent,
    launcher_type: provider_session.LauncherTypes,
    launcher_id: typing.Union[int, str],
    sender_id: typing.Union[int, str],
    adapter: Mock,
    **overrides,
) -> pipeline_query.Query:
    """Assemble a ``Query`` with ``model_construct``, bypassing pydantic validation.

    Every call takes a new ID from ``_next_query_id`` and builds new list/dict
    defaults, so generated queries never share mutable state.

    Args:
        message_chain: Chain stored as ``message_chain``.
        message_event: Event stored as ``message_event``.
        launcher_type: Launcher kind (person or group).
        launcher_id: ID of the launching chat.
        sender_id: ID of the message sender.
        adapter: Platform adapter double, typically from ``mock_adapter()``.
        **overrides: Field values that replace the defaults wholesale (for
            example a complete ``pipeline_config``); they are not deep-merged.
    """
    defaults = {
        "query_id": _next_query_id(),
        "launcher_type": launcher_type,
        "launcher_id": launcher_id,
        "sender_id": sender_id,
        "message_chain": message_chain,
        "message_event": message_event,
        "adapter": adapter,
        "pipeline_uuid": "test-pipeline-uuid",
        "bot_uuid": "test-bot-uuid",
        "pipeline_config": _default_pipeline_config(),
        "session": None,
        "prompt": None,
        "messages": [],
        "user_message": None,
        "use_funcs": [],
        "use_llm_model_uuid": None,
        "variables": {},
        "resp_messages": [],
        "resp_message_chain": None,
        "current_stage_name": None,
    }
    # Later keys win: overrides replace defaults, and unknown keys are appended.
    return pipeline_query.Query.model_construct(**{**defaults, **overrides})
|
|
|
|
|
|
def text_query(
    text: str = "hello",
    sender_id: typing.Union[int, str] = 12345,
    **overrides,
) -> pipeline_query.Query:
    """Build a private-chat query carrying a plain-text message.

    The sender is also the launcher (``LauncherTypes.PERSON``). Extra keyword
    arguments are forwarded to ``_base_query`` as field overrides.
    """
    message = text_chain(text)
    return _base_query(
        message_chain=message,
        message_event=friend_message_event(message, sender_id),
        launcher_type=provider_session.LauncherTypes.PERSON,
        launcher_id=sender_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def private_text_query(
    text: str = "hello",
    sender_id: typing.Union[int, str] = 12345,
    **overrides,
) -> pipeline_query.Query:
    """Build a private-chat text query; an alias of ``text_query``.

    The name makes the private-chat intent explicit at call sites.
    """
    return text_query(text=text, sender_id=sender_id, **overrides)
|
|
|
|
|
|
def group_text_query(
    text: str = "hello",
    sender_id: typing.Union[int, str] = 12345,
    group_id: typing.Union[int, str] = 99999,
    **overrides,
) -> pipeline_query.Query:
    """Build a group-chat query carrying a plain-text message.

    The launcher is the group (``LauncherTypes.GROUP`` with ``group_id``).
    ``sender_id`` identifies the member who posted the message.
    """
    message = text_chain(text)
    group_event = group_message_event(message, sender_id, group_id=group_id)
    return _base_query(
        message_chain=message,
        message_event=group_event,
        launcher_type=provider_session.LauncherTypes.GROUP,
        launcher_id=group_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def command_query(
    command: str = "help",
    prefix: str = "/",
    sender_id: typing.Union[int, str] = 12345,
    **overrides,
) -> pipeline_query.Query:
    """Build a private-chat query whose text looks like a command (e.g. ``/help``).

    Args:
        command: Command name, without the prefix.
        prefix: Prefix written immediately before ``command``.
        sender_id: Sender and launcher ID of the private chat.
        **overrides: Field overrides forwarded to ``_base_query``.
    """
    command_message = command_chain(command, prefix)
    return _base_query(
        message_chain=command_message,
        message_event=friend_message_event(command_message, sender_id),
        launcher_type=provider_session.LauncherTypes.PERSON,
        launcher_id=sender_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def mention_query(
    text: str = "hello",
    target: typing.Union[int, str] = 12345,
    sender_id: typing.Union[int, str] = 12345,
    group_id: typing.Union[int, str] = 99999,
    **overrides,
) -> pipeline_query.Query:
    """Build a group-chat query whose message @mentions ``target`` before ``text``.

    NOTE(review): ``target`` and ``sender_id`` share the default 12345, so by
    default the sender mentions its own ID. Pass the bot's ID as ``target``
    when a test needs a genuine mention of the bot; confirm which is intended.
    """
    mention_message = mention_chain(text, target)
    group_event = group_message_event(mention_message, sender_id, group_id=group_id)
    return _base_query(
        message_chain=mention_message,
        message_event=group_event,
        launcher_type=provider_session.LauncherTypes.GROUP,
        launcher_id=group_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def empty_query(**overrides) -> pipeline_query.Query:
    """Build a private-chat query whose message chain has no segments.

    The sender and launcher IDs are both fixed at 12345, which matches the
    event's default sender. Extra keyword arguments are forwarded to
    ``_base_query`` as field overrides.
    """
    empty_chain = platform_message.MessageChain([])
    fixed_id = 12345
    return _base_query(
        message_chain=empty_chain,
        message_event=friend_message_event(empty_chain),
        launcher_type=provider_session.LauncherTypes.PERSON,
        launcher_id=fixed_id,
        sender_id=fixed_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def image_query(
    text: str = "",
    url: str = "https://example.com/image.png",
    sender_id: typing.Union[int, str] = 12345,
    **overrides,
) -> pipeline_query.Query:
    """Build a private-chat query carrying an image with an optional text caption.

    Args:
        text: Caption placed before the image; omitted when empty.
        url: Image URL.
        sender_id: Sender and launcher ID of the private chat.
        **overrides: Field overrides forwarded to ``_base_query``.
    """
    picture_message = image_chain(text, url)
    return _base_query(
        message_chain=picture_message,
        message_event=friend_message_event(picture_message, sender_id),
        launcher_type=provider_session.LauncherTypes.PERSON,
        launcher_id=sender_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def file_query(
    url: str = "https://example.com/document.pdf",
    name: str = "document.pdf",
    text: str = "",
    sender_id: typing.Union[int, str] = 12345,
    **overrides,
) -> pipeline_query.Query:
    """Build a private-chat query carrying a file attachment.

    Args:
        url: URL of the attached file.
        name: File name stored on the ``File`` segment.
        text: Optional text placed before the file; omitted when empty.
        sender_id: Sender and launcher ID of the private chat.
        **overrides: Field overrides forwarded to ``_base_query``.
    """
    segments = [platform_message.Plain(text=text)] if text else []
    segments.append(platform_message.File(url=url, name=name))
    attachment_message = platform_message.MessageChain(segments)
    return _base_query(
        message_chain=attachment_message,
        message_event=friend_message_event(attachment_message, sender_id),
        launcher_type=provider_session.LauncherTypes.PERSON,
        launcher_id=sender_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def unsupported_query(
    unsupported_type: str = "CustomComponent",
    text: str = "",
    sender_id: typing.Union[int, str] = 12345,
    **overrides,
) -> pipeline_query.Query:
    """Build a private-chat query containing a segment the pipeline does not recognise.

    The unrecognised segment is modelled as ``platform_message.Unknown``, and
    its text names ``unsupported_type``.

    Args:
        unsupported_type: Label embedded in the ``Unknown`` segment's text.
        text: Optional text placed before the unknown segment; omitted when empty.
        sender_id: Sender and launcher ID of the private chat.
        **overrides: Field overrides forwarded to ``_base_query``.
    """
    segments = [platform_message.Plain(text=text)] if text else []
    segments.append(platform_message.Unknown(text=f"Unsupported: {unsupported_type}"))
    odd_message = platform_message.MessageChain(segments)
    return _base_query(
        message_chain=odd_message,
        message_event=friend_message_event(odd_message, sender_id),
        launcher_type=provider_session.LauncherTypes.PERSON,
        launcher_id=sender_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def query_with_session(
    text: str = "hello",
    sender_id: typing.Union[int, str] = 12345,
    session: provider_session.Session = None,
    **overrides,
) -> pipeline_query.Query:
    """Build a private text query with a ``session`` attached.

    When ``session`` is None, a ``PERSON`` session for ``sender_id`` is
    created. It uses the prompt name ``"default"``, has no active conversation
    and has an empty conversation list.
    """
    active_session = (
        session
        if session is not None
        else provider_session.Session(
            launcher_type=provider_session.LauncherTypes.PERSON,
            launcher_id=sender_id,
            sender_id=sender_id,
            use_prompt_name="default",
            using_conversation=None,
            conversations=[],
        )
    )
    return text_query(text, sender_id, session=active_session, **overrides)
|
|
|
|
|
|
def query_with_config(
    text: str = "hello",
    sender_id: typing.Union[int, str] = 12345,
    pipeline_config: typing.Optional[dict] = None,
    **overrides,
) -> pipeline_query.Query:
    """Build a private text query, optionally with a custom pipeline configuration.

    Useful for testing specific stage behaviors.

    Args:
        text: Plain-text message content.
        sender_id: Sender and launcher ID of the private chat.
        pipeline_config: Complete pipeline config to install on the query.
            It replaces the default wholesale and is not merged. When None, the
            query keeps the default config built by ``_base_query``, which
            selects ``DEFAULT_RUNNER_ID`` via ``ai.runner.id`` and
            ``ai.runner_config``. Previously a stale legacy-schema config
            (``ai.runner.runner`` / ``ai["local-agent"]``) was installed here
            instead, which did not match that default.
        **overrides: Additional ``Query`` field overrides.
    """
    if pipeline_config is not None:
        # ``overrides`` is this call's private kwargs dict, so adding a key is safe.
        overrides["pipeline_config"] = pipeline_config
    return text_query(text, sender_id, **overrides)
|
|
|
|
|
|
def voice_query(
    url: str = "https://example.com/audio.mp3",
    sender_id: typing.Union[int, str] = 12345,
    **overrides,
) -> pipeline_query.Query:
    """Build a private-chat query whose chain is a single ``Voice`` segment.

    Args:
        url: Audio URL stored on the ``Voice`` segment.
        sender_id: Sender and launcher ID of the private chat.
        **overrides: Field overrides forwarded to ``_base_query``.
    """
    audio_message = platform_message.MessageChain([platform_message.Voice(url=url)])
    return _base_query(
        message_chain=audio_message,
        message_event=friend_message_event(audio_message, sender_id),
        launcher_type=provider_session.LauncherTypes.PERSON,
        launcher_id=sender_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|
|
|
|
|
|
def at_all_query(
    text: str = "hello",
    sender_id: typing.Union[int, str] = 12345,
    group_id: typing.Union[int, str] = 99999,
    **overrides,
) -> pipeline_query.Query:
    """Build a group-chat query whose message starts with an @All mention.

    The ``AtAll`` segment is followed by ``text``, which is prefixed with one space.
    The launcher is the group (``LauncherTypes.GROUP`` with ``group_id``).
    """
    broadcast_message = platform_message.MessageChain(
        [
            platform_message.AtAll(),
            platform_message.Plain(text=f" {text}"),
        ]
    )
    group_event = group_message_event(broadcast_message, sender_id, group_id=group_id)
    return _base_query(
        message_chain=broadcast_message,
        message_event=group_event,
        launcher_type=provider_session.LauncherTypes.GROUP,
        launcher_id=group_id,
        sender_id=sender_id,
        adapter=mock_adapter(),
        **overrides,
    )
|