mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
Merge branch 'langbot-app:master' into master
This commit is contained in:
60
.github/workflows/lint.yml
vendored
Normal file
60
.github/workflows/lint.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
- dev
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
name: Ruff Lint & Format
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --dev
|
||||
|
||||
- name: Run ruff check
|
||||
run: uv run ruff check src
|
||||
|
||||
- name: Run ruff format
|
||||
run: uv run ruff format src --check
|
||||
|
||||
frontend:
|
||||
name: Frontend Lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '25'
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 9
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: web
|
||||
run: pnpm install
|
||||
|
||||
- name: Run lint
|
||||
working-directory: web
|
||||
run: pnpm lint
|
||||
@@ -23,7 +23,7 @@ dependencies = [
|
||||
"pynacl>=1.5.0", # Required for Discord voice support
|
||||
"gewechat-client>=0.1.5",
|
||||
"lark-oapi>=1.4.15",
|
||||
"mcp>=1.20.0",
|
||||
"mcp>=1.25.0",
|
||||
"nakuru-project-idk>=0.0.2.1",
|
||||
"ollama>=0.4.8",
|
||||
"openai>1.0.0",
|
||||
|
||||
@@ -85,7 +85,6 @@ class QQOfficialClient:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
|
||||
body = await req.get_data()
|
||||
|
||||
print(f'[QQ Official] Received request, body length: {len(body)}')
|
||||
@@ -96,7 +95,6 @@ class QQOfficialClient:
|
||||
|
||||
payload = json.loads(body)
|
||||
|
||||
|
||||
if payload.get('op') == 13:
|
||||
validation_data = payload.get('d')
|
||||
if not validation_data:
|
||||
@@ -276,21 +274,21 @@ class QQOfficialClient:
|
||||
seed = bot_secret
|
||||
while len(seed) < target_size:
|
||||
seed *= 2
|
||||
return seed[:target_size].encode("utf-8")
|
||||
return seed[:target_size].encode('utf-8')
|
||||
|
||||
async def verify(self, validation_payload: dict):
|
||||
seed = await self.repeat_seed(self.secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
|
||||
event_ts = validation_payload.get("event_ts", "")
|
||||
plain_token = validation_payload.get("plain_token", "")
|
||||
event_ts = validation_payload.get('event_ts', '')
|
||||
plain_token = validation_payload.get('plain_token', '')
|
||||
msg = event_ts + plain_token
|
||||
|
||||
# sign
|
||||
signature = private_key.sign(msg.encode()).hex()
|
||||
|
||||
response = {
|
||||
"plain_token": plain_token,
|
||||
"signature": signature,
|
||||
'plain_token': plain_token,
|
||||
'signature': signature,
|
||||
}
|
||||
return response
|
||||
|
||||
@@ -36,7 +36,12 @@ class WecomBotEvent(dict):
|
||||
"""
|
||||
用户名称
|
||||
"""
|
||||
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid
|
||||
return (
|
||||
self.get('username', '')
|
||||
or self.get('from', {}).get('alias', '')
|
||||
or self.get('from', {}).get('name', '')
|
||||
or self.userid
|
||||
)
|
||||
|
||||
@property
|
||||
def chatname(self) -> str:
|
||||
|
||||
@@ -30,7 +30,6 @@ class WebhookRouterGroup(group.RouterGroup):
|
||||
适配器返回的响应
|
||||
"""
|
||||
try:
|
||||
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
|
||||
if not runtime_bot:
|
||||
@@ -39,11 +38,9 @@ class WebhookRouterGroup(group.RouterGroup):
|
||||
if not runtime_bot.enable:
|
||||
return quart.jsonify({'error': 'Bot is disabled'}), 403
|
||||
|
||||
|
||||
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
|
||||
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
|
||||
|
||||
|
||||
response = await runtime_bot.adapter.handle_unified_webhook(
|
||||
bot_uuid=bot_uuid,
|
||||
path=path,
|
||||
|
||||
@@ -59,7 +59,16 @@ class BotService:
|
||||
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||
|
||||
# Webhook URL for unified webhook adapters (independent of bot running state)
|
||||
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE', 'lark']:
|
||||
if persistence_bot['adapter'] in [
|
||||
'wecom',
|
||||
'wecombot',
|
||||
'officialaccount',
|
||||
'qqofficial',
|
||||
'slack',
|
||||
'wecomcs',
|
||||
'LINE',
|
||||
'lark',
|
||||
]:
|
||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||
webhook_url = f'/bots/{bot_uuid}'
|
||||
adapter_runtime_values['webhook_url'] = webhook_url
|
||||
|
||||
@@ -34,7 +34,6 @@ from .. import taskmgr
|
||||
from ...telemetry import telemetry as telemetry_module
|
||||
|
||||
|
||||
|
||||
@stage.stage_class('BuildAppStage')
|
||||
class BuildAppStage(stage.BootingStage):
|
||||
"""Build LangBot application"""
|
||||
|
||||
@@ -75,10 +75,17 @@ class RuntimeBot:
|
||||
|
||||
# Only add to query pool if no webhook requested to skip pipeline
|
||||
if not skip_pipeline:
|
||||
launcher_id = event.sender.id
|
||||
|
||||
if hasattr(adapter, 'get_launcher_id'):
|
||||
custom_launcher_id = adapter.get_launcher_id(event)
|
||||
if custom_launcher_id:
|
||||
launcher_id = custom_launcher_id
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
@@ -86,7 +93,7 @@ class RuntimeBot:
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid,
|
||||
)
|
||||
else:
|
||||
await self.logger.info(f'Pipeline skipped for person message due to webhook response')
|
||||
await self.logger.info('Pipeline skipped for person message due to webhook response')
|
||||
|
||||
async def on_group_message(
|
||||
event: platform_events.GroupMessage,
|
||||
@@ -111,10 +118,17 @@ class RuntimeBot:
|
||||
|
||||
# Only add to query pool if no webhook requested to skip pipeline
|
||||
if not skip_pipeline:
|
||||
launcher_id = event.group.id
|
||||
|
||||
if hasattr(adapter, 'get_launcher_id'):
|
||||
custom_launcher_id = adapter.get_launcher_id(event)
|
||||
if custom_launcher_id:
|
||||
launcher_id = custom_launcher_id
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
@@ -122,7 +136,7 @@ class RuntimeBot:
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid,
|
||||
)
|
||||
else:
|
||||
await self.logger.info(f'Pipeline skipped for group message due to webhook response')
|
||||
await self.logger.info('Pipeline skipped for group message due to webhook response')
|
||||
|
||||
self.adapter.register_listener(platform_events.FriendMessage, on_friend_message)
|
||||
self.adapter.register_listener(platform_events.GroupMessage, on_group_message)
|
||||
|
||||
@@ -244,7 +244,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
|
||||
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
|
||||
|
||||
|
||||
if message.message_type == 'text':
|
||||
element_list = []
|
||||
|
||||
@@ -310,7 +309,11 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
]
|
||||
elif message.message_type == 'audio':
|
||||
message_content['content'] = [
|
||||
{'tag': 'audio', 'file_key': message_content['file_key'], "duration": message_content.get('duration',0)}
|
||||
{
|
||||
'tag': 'audio',
|
||||
'file_key': message_content['file_key'],
|
||||
'duration': message_content.get('duration', 0),
|
||||
}
|
||||
]
|
||||
|
||||
for ele in message_content['content']:
|
||||
@@ -367,12 +370,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
audio_bytes = response.file.read()
|
||||
audio_base64 = base64.b64encode(audio_bytes).decode()
|
||||
|
||||
|
||||
# Get content type from response headers
|
||||
content_type = response.raw.headers.get('content-type', 'audio/mpeg')
|
||||
|
||||
|
||||
|
||||
mime_main = content_type.split(';')[0].strip()
|
||||
ext = mimetypes.guess_extension(mime_main) or '.bin'
|
||||
temp_dir = tempfile.gettempdir()
|
||||
@@ -418,7 +418,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
file_bytes = response.file.read()
|
||||
file_base64 = base64.b64encode(file_bytes).decode()
|
||||
|
||||
|
||||
file_format = response.raw.headers['content-type']
|
||||
|
||||
file_size = len(file_bytes)
|
||||
@@ -453,7 +452,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
return platform_message.MessageChain(lb_msg_list)
|
||||
|
||||
|
||||
|
||||
@@ -197,6 +197,10 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
}
|
||||
if self.config['markdown_card'] is True:
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
|
||||
if message_source.source_platform_object.message.message_thread_id:
|
||||
args['message_thread_id'] = message_source.source_platform_object.message.message_thread_id
|
||||
|
||||
if quote_origin:
|
||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
||||
|
||||
@@ -231,8 +235,12 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
||||
'text': content,
|
||||
}
|
||||
if message_source.source_platform_object.message.message_thread_id:
|
||||
args['message_thread_id'] = message_source.source_platform_object.message.message_thread_id
|
||||
|
||||
if quote_origin:
|
||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
||||
|
||||
if self.config['markdown_card'] is True:
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
|
||||
@@ -260,6 +268,24 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
# self.seq = 1 # 消息回复结束之后重置seq
|
||||
self.msg_stream_id.pop(message_id) # 消息回复结束之后删除流式消息id
|
||||
|
||||
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
||||
if not isinstance(event.source_platform_object, Update):
|
||||
return None
|
||||
|
||||
message = event.source_platform_object.message
|
||||
if not message:
|
||||
return None
|
||||
|
||||
# specifically handle telegram forum topic and private thread(not supported by official client yet but supported by bot api)
|
||||
if message.message_thread_id:
|
||||
# check if it is a group
|
||||
if isinstance(event, platform_events.GroupMessage):
|
||||
return f'{event.group.id}#{message.message_thread_id}'
|
||||
elif isinstance(event, platform_events.FriendMessage):
|
||||
return f'{event.sender.id}#{message.message_thread_id}'
|
||||
|
||||
return None
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
is_stream = False
|
||||
if self.config.get('enable-stream-reply', None):
|
||||
|
||||
@@ -75,13 +75,15 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
chunks = split_string_by_bytes(msg.text)
|
||||
content_list.extend([
|
||||
content_list.extend(
|
||||
[
|
||||
{
|
||||
'type': 'text',
|
||||
'content': chunk,
|
||||
}
|
||||
for chunk in chunks
|
||||
])
|
||||
]
|
||||
)
|
||||
elif type(msg) is platform_message.Image:
|
||||
content_list.append(
|
||||
{
|
||||
|
||||
@@ -56,7 +56,7 @@ class WebhookPusher:
|
||||
# Check if any webhook responded with skip_pipeline=true
|
||||
for result in results:
|
||||
if isinstance(result, dict) and result.get('skip_pipeline') is True:
|
||||
self.logger.info(f'Webhook responded with skip_pipeline=true, skipping pipeline for person message')
|
||||
self.logger.info('Webhook responded with skip_pipeline=true, skipping pipeline for person message')
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -103,7 +103,7 @@ class WebhookPusher:
|
||||
# Check if any webhook responded with skip_pipeline=true
|
||||
for result in results:
|
||||
if isinstance(result, dict) and result.get('skip_pipeline') is True:
|
||||
self.logger.info(f'Webhook responded with skip_pipeline=true, skipping pipeline for group message')
|
||||
self.logger.info('Webhook responded with skip_pipeline=true, skipping pipeline for group message')
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -51,9 +51,10 @@ class SeekDBEmbedding(requester.ProviderAPIRequester):
|
||||
await self.initialize()
|
||||
|
||||
if self._embedding_function is None:
|
||||
raise RuntimeError("SeekDB embedding function initialization failed")
|
||||
raise RuntimeError('SeekDB embedding function initialization failed')
|
||||
|
||||
return self._embedding_function(input_text)
|
||||
except Exception as e:
|
||||
from .. import errors
|
||||
|
||||
raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}')
|
||||
|
||||
@@ -212,13 +212,20 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
if func.arguments:
|
||||
parameters = json.loads(func.arguments)
|
||||
else:
|
||||
parameters = {}
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query)
|
||||
|
||||
# Handle return value content
|
||||
tool_content = None
|
||||
if isinstance(func_ret, list) and len(func_ret) > 0 and isinstance(func_ret[0], provider_message.ContentElement):
|
||||
if (
|
||||
isinstance(func_ret, list)
|
||||
and len(func_ret) > 0
|
||||
and isinstance(func_ret[0], provider_message.ContentElement)
|
||||
):
|
||||
tool_content = func_ret
|
||||
else:
|
||||
tool_content = json.dumps(func_ret, ensure_ascii=False)
|
||||
|
||||
@@ -68,15 +68,16 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
return plain_text
|
||||
|
||||
async def _process_stream_response(self, response: aiohttp.ClientResponse) -> typing.AsyncGenerator[
|
||||
provider_message.Message, None]:
|
||||
async def _process_stream_response(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况"""
|
||||
full_content = ""
|
||||
full_content = ''
|
||||
chunk_idx = 0
|
||||
is_final = False
|
||||
message_idx = 0
|
||||
|
||||
buffer = ""
|
||||
buffer = ''
|
||||
decoder = json.JSONDecoder()
|
||||
|
||||
async for raw_chunk in response.content.iter_chunked(1024):
|
||||
@@ -129,7 +130,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
preview = chunk_str[:200]
|
||||
except Exception:
|
||||
preview = '<unavailable>'
|
||||
self.ap.logger.warning(f"Failed to process chunk: {e}; chunk preview: {preview}")
|
||||
self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}')
|
||||
|
||||
# 流结束后,尝试解析残余 buffer
|
||||
if buffer:
|
||||
@@ -151,7 +152,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
)
|
||||
except Exception as e:
|
||||
preview = buffer[:200]
|
||||
self.ap.logger.warning(f"Failed to parse remaining buffer: {e}; buffer preview: {preview}")
|
||||
self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}')
|
||||
|
||||
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用n8n webhook"""
|
||||
@@ -220,11 +221,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with session.post(
|
||||
self.webhook_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
auth=auth,
|
||||
timeout=self.timeout
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
@@ -236,11 +233,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
yield chunk
|
||||
else:
|
||||
async with session.post(
|
||||
self.webhook_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
auth=auth,
|
||||
timeout=self.timeout
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
try:
|
||||
async for chunk in self._process_stream_response(response):
|
||||
|
||||
@@ -194,7 +194,7 @@ class RuntimeMCPSession:
|
||||
|
||||
async def func(*, _tool=tool, **kwargs):
|
||||
if not self.session:
|
||||
raise Exception("MCP session is not connected")
|
||||
raise Exception('MCP session is not connected')
|
||||
|
||||
result = await self.session.call_tool(_tool.name, kwargs)
|
||||
if result.isError:
|
||||
@@ -202,7 +202,7 @@ class RuntimeMCPSession:
|
||||
for content in result.content:
|
||||
if content.type == 'text':
|
||||
error_texts.append(content.text)
|
||||
raise Exception("\n".join(error_texts) if error_texts else "Unknown error from MCP tool")
|
||||
raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool')
|
||||
|
||||
result_contents: list[provider_message.ContentElement] = []
|
||||
for content in result.content:
|
||||
@@ -221,8 +221,8 @@ class RuntimeMCPSession:
|
||||
self.functions.append(
|
||||
resource_tool.LLMTool(
|
||||
name=tool.name,
|
||||
human_desc=tool.description or "",
|
||||
description=tool.description or "",
|
||||
human_desc=tool.description or '',
|
||||
description=tool.description or '',
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
)
|
||||
@@ -338,13 +338,10 @@ class MCPLoader(loader.ToolLoader):
|
||||
"""
|
||||
uuid_ = server_config.get('uuid')
|
||||
if not uuid_:
|
||||
self.ap.logger.warning(
|
||||
'Server UUID is None for MCP server, maybe testing in the config page.'
|
||||
)
|
||||
self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.')
|
||||
uuid_ = str(uuid_module.uuid4())
|
||||
server_config['uuid'] = uuid_
|
||||
|
||||
|
||||
name = server_config['name']
|
||||
uuid = server_config['uuid']
|
||||
mode = server_config['mode']
|
||||
|
||||
@@ -55,12 +55,7 @@ class VectorDBManager:
|
||||
user = pgvector_config.get('user', 'postgres')
|
||||
password = pgvector_config.get('password', 'postgres')
|
||||
self.vector_db = PgVectorDatabase(
|
||||
self.ap,
|
||||
host=host,
|
||||
port=port,
|
||||
database=database,
|
||||
user=user,
|
||||
password=password
|
||||
self.ap, host=host, port=port, database=database, user=user, password=password
|
||||
)
|
||||
self.ap.logger.info('Initialized pgvector database backend.')
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from langbot.pkg.core import app
|
||||
class MilvusVectorDatabase(VectorDatabase):
|
||||
"""Milvus vector database implementation"""
|
||||
|
||||
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None):
|
||||
def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None):
|
||||
"""Initialize Milvus vector database
|
||||
|
||||
Args:
|
||||
@@ -34,9 +34,9 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
|
||||
else:
|
||||
self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
|
||||
self.ap.logger.info(f"Connected to Milvus at {self.uri}")
|
||||
self.ap.logger.info(f'Connected to Milvus at {self.uri}')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to Milvus: {e}")
|
||||
self.ap.logger.error(f'Failed to connect to Milvus: {e}')
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
@@ -70,15 +70,11 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
"""
|
||||
index_params = IndexParams()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="AUTOINDEX",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self.client.create_index,
|
||||
collection_name=collection,
|
||||
index_params=index_params
|
||||
field_name='vector',
|
||||
index_type='AUTOINDEX',
|
||||
metric_type='COSINE',
|
||||
)
|
||||
await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params)
|
||||
|
||||
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
|
||||
"""Internal method to get or create a Milvus collection with proper configuration.
|
||||
@@ -94,9 +90,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
return collection
|
||||
|
||||
# Check if collection exists
|
||||
has_collection = await asyncio.to_thread(
|
||||
self.client.has_collection, collection_name=collection
|
||||
)
|
||||
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
|
||||
|
||||
if not has_collection:
|
||||
# Default dimension if not specified (for backward compatibility)
|
||||
@@ -104,24 +98,26 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
vector_size = 1536
|
||||
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
||||
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255),
|
||||
]
|
||||
|
||||
schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors")
|
||||
schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors')
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
collection_name=collection,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
metric_type='COSINE',
|
||||
)
|
||||
|
||||
await self._ensure_vector_index(collection)
|
||||
self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX")
|
||||
self.ap.logger.info(
|
||||
f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX"
|
||||
)
|
||||
else:
|
||||
# Ensure index exists for existing collection
|
||||
await self._ensure_index_if_missing(collection)
|
||||
@@ -137,11 +133,8 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
collection: Normalized collection name
|
||||
"""
|
||||
try:
|
||||
indexes = await asyncio.to_thread(
|
||||
self.client.list_indexes,
|
||||
collection_name=collection
|
||||
)
|
||||
if "vector" not in indexes:
|
||||
indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection)
|
||||
if 'vector' not in indexes:
|
||||
await self._ensure_vector_index(collection)
|
||||
self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
|
||||
except Exception as e:
|
||||
@@ -184,39 +177,30 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
data = []
|
||||
for i, vector_id in enumerate(ids):
|
||||
entry = {
|
||||
"id": vector_id,
|
||||
"vector": embeddings_list[i],
|
||||
'id': vector_id,
|
||||
'vector': embeddings_list[i],
|
||||
}
|
||||
# Add metadata fields
|
||||
if metadatas and i < len(metadatas):
|
||||
metadata = metadatas[i]
|
||||
# Add common metadata fields
|
||||
if "text" in metadata:
|
||||
entry["text"] = metadata["text"]
|
||||
if "file_id" in metadata:
|
||||
entry["file_id"] = metadata["file_id"]
|
||||
if "uuid" in metadata:
|
||||
entry["chunk_uuid"] = metadata["uuid"]
|
||||
if 'text' in metadata:
|
||||
entry['text'] = metadata['text']
|
||||
if 'file_id' in metadata:
|
||||
entry['file_id'] = metadata['file_id']
|
||||
if 'uuid' in metadata:
|
||||
entry['chunk_uuid'] = metadata['uuid']
|
||||
data.append(entry)
|
||||
|
||||
# Insert data into Milvus
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
collection_name=collection,
|
||||
data=data
|
||||
)
|
||||
await asyncio.to_thread(self.client.insert, collection_name=collection, data=data)
|
||||
|
||||
# Load collection for searching (Milvus requires this)
|
||||
await asyncio.to_thread(
|
||||
self.client.load_collection,
|
||||
collection_name=collection
|
||||
)
|
||||
await asyncio.to_thread(self.client.load_collection, collection_name=collection)
|
||||
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
||||
|
||||
async def search(
|
||||
self, collection: str, query_embedding: list[float], k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for similar vectors in Milvus collection
|
||||
|
||||
Args:
|
||||
@@ -231,10 +215,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Perform search
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {}
|
||||
}
|
||||
search_params = {'metric_type': 'COSINE', 'params': {}}
|
||||
|
||||
results = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
@@ -242,7 +223,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
data=[query_embedding],
|
||||
limit=k,
|
||||
search_params=search_params,
|
||||
output_fields=["text", "file_id", "chunk_uuid"]
|
||||
output_fields=['text', 'file_id', 'chunk_uuid'],
|
||||
)
|
||||
|
||||
# Convert results to Chroma-compatible format
|
||||
@@ -253,30 +234,24 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
|
||||
if results and len(results) > 0:
|
||||
for hit in results[0]:
|
||||
ids.append(hit.get("id", ""))
|
||||
distances.append(hit.get("distance", 0.0))
|
||||
ids.append(hit.get('id', ''))
|
||||
distances.append(hit.get('distance', 0.0))
|
||||
|
||||
# Build metadata from entity fields
|
||||
entity = hit.get("entity", {})
|
||||
entity = hit.get('entity', {})
|
||||
metadata = {}
|
||||
if "text" in entity:
|
||||
metadata["text"] = entity["text"]
|
||||
if "file_id" in entity:
|
||||
metadata["file_id"] = entity["file_id"]
|
||||
if "chunk_uuid" in entity:
|
||||
metadata["uuid"] = entity["chunk_uuid"]
|
||||
if 'text' in entity:
|
||||
metadata['text'] = entity['text']
|
||||
if 'file_id' in entity:
|
||||
metadata['file_id'] = entity['file_id']
|
||||
if 'chunk_uuid' in entity:
|
||||
metadata['uuid'] = entity['chunk_uuid']
|
||||
metadatas.append(metadata)
|
||||
|
||||
# Return in Chroma-compatible format (nested lists)
|
||||
result = {
|
||||
"ids": [ids],
|
||||
"distances": [distances],
|
||||
"metadatas": [metadatas]
|
||||
}
|
||||
result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
|
||||
|
||||
self.ap.logger.info(
|
||||
f"Milvus search in '{collection}' returned {len(ids)} results"
|
||||
)
|
||||
self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results")
|
||||
return result
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
@@ -290,14 +265,8 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Delete entities matching the file_id
|
||||
await asyncio.to_thread(
|
||||
self.client.delete,
|
||||
collection_name=collection,
|
||||
filter=f'file_id == "{file_id}"'
|
||||
)
|
||||
self.ap.logger.info(
|
||||
f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}"
|
||||
)
|
||||
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
|
||||
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete a Milvus collection
|
||||
@@ -310,14 +279,10 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
self._collections.discard(collection)
|
||||
|
||||
# Check if collection exists before attempting deletion
|
||||
has_collection = await asyncio.to_thread(
|
||||
self.client.has_collection, collection_name=collection
|
||||
)
|
||||
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
|
||||
|
||||
if has_collection:
|
||||
await asyncio.to_thread(
|
||||
self.client.drop_collection, collection_name=collection
|
||||
)
|
||||
await asyncio.to_thread(self.client.drop_collection, collection_name=collection)
|
||||
self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
|
||||
else:
|
||||
self.ap.logger.warning(f"Milvus collection '{collection}' not found")
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from sqlalchemy import create_engine, text, Column, String, Text
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
from langbot.pkg.core import app
|
||||
import uuid
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class PgVectorEntry(Base):
|
||||
"""SQLAlchemy model for pgvector entries"""
|
||||
|
||||
__tablename__ = 'langbot_vectors'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
@@ -31,11 +30,11 @@ class PgVectorDatabase(VectorDatabase):
|
||||
self,
|
||||
ap: app.Application,
|
||||
connection_string: str = None,
|
||||
host: str = "localhost",
|
||||
host: str = 'localhost',
|
||||
port: int = 5432,
|
||||
database: str = "langbot",
|
||||
user: str = "postgres",
|
||||
password: str = "postgres"
|
||||
database: str = 'langbot',
|
||||
user: str = 'postgres',
|
||||
password: str = 'postgres',
|
||||
):
|
||||
"""Initialize pgvector database
|
||||
|
||||
@@ -54,14 +53,10 @@ class PgVectorDatabase(VectorDatabase):
|
||||
if connection_string:
|
||||
self.connection_string = connection_string
|
||||
else:
|
||||
self.connection_string = (
|
||||
f"postgresql+psycopg://{user}:{password}@{host}:{port}/{database}"
|
||||
)
|
||||
self.connection_string = f'postgresql+psycopg://{user}:{password}@{host}:{port}/{database}'
|
||||
|
||||
self.async_connection_string = self.connection_string.replace(
|
||||
"postgresql://", "postgresql+asyncpg://"
|
||||
).replace(
|
||||
"postgresql+psycopg://", "postgresql+asyncpg://"
|
||||
self.async_connection_string = self.connection_string.replace('postgresql://', 'postgresql+asyncpg://').replace(
|
||||
'postgresql+psycopg://', 'postgresql+asyncpg://'
|
||||
)
|
||||
|
||||
self.engine = None
|
||||
@@ -75,35 +70,25 @@ class PgVectorDatabase(VectorDatabase):
|
||||
"""Initialize database connection and create tables"""
|
||||
try:
|
||||
# Create async engine for async operations
|
||||
self.async_engine = create_async_engine(
|
||||
self.async_connection_string,
|
||||
echo=False,
|
||||
pool_pre_ping=True
|
||||
)
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
self.async_engine = create_async_engine(self.async_connection_string, echo=False, pool_pre_ping=True)
|
||||
self.AsyncSessionLocal = async_sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
# Create sync engine for table creation
|
||||
sync_connection_string = self.connection_string.replace(
|
||||
"postgresql+asyncpg://", "postgresql+psycopg://"
|
||||
)
|
||||
sync_connection_string = self.connection_string.replace('postgresql+asyncpg://', 'postgresql+psycopg://')
|
||||
self.engine = create_engine(sync_connection_string, echo=False)
|
||||
|
||||
# Create pgvector extension and tables
|
||||
with self.engine.connect() as conn:
|
||||
# Enable pgvector extension
|
||||
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
|
||||
conn.commit()
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(self.engine)
|
||||
|
||||
self.ap.logger.info(f"Connected to PostgreSQL with pgvector")
|
||||
self.ap.logger.info('Connected to PostgreSQL with pgvector')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to PostgreSQL: {e}")
|
||||
self.ap.logger.error(f'Failed to connect to PostgreSQL: {e}')
|
||||
raise
|
||||
|
||||
async def get_or_create_collection(self, collection: str):
|
||||
@@ -144,24 +129,20 @@ class PgVectorDatabase(VectorDatabase):
|
||||
id=vector_id,
|
||||
collection=collection,
|
||||
embedding=embeddings_list[i],
|
||||
text=metadata.get("text", ""),
|
||||
file_id=metadata.get("file_id", ""),
|
||||
chunk_uuid=metadata.get("uuid", "")
|
||||
text=metadata.get('text', ''),
|
||||
file_id=metadata.get('file_id', ''),
|
||||
chunk_uuid=metadata.get('uuid', ''),
|
||||
)
|
||||
session.add(entry)
|
||||
|
||||
await session.commit()
|
||||
self.ap.logger.info(
|
||||
f"Added {len(ids)} embeddings to pgvector collection '{collection}'"
|
||||
)
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to pgvector collection '{collection}'")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error adding embeddings to pgvector: {e}")
|
||||
self.ap.logger.error(f'Error adding embeddings to pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def search(
|
||||
self, collection: str, query_embedding: list[float], k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for similar vectors using cosine distance
|
||||
|
||||
Args:
|
||||
@@ -177,7 +158,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
try:
|
||||
# Use cosine distance for similarity search
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
|
||||
# Query for similar vectors
|
||||
stmt = (
|
||||
@@ -186,7 +167,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
PgVectorEntry.text,
|
||||
PgVectorEntry.file_id,
|
||||
PgVectorEntry.chunk_uuid,
|
||||
PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance')
|
||||
PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance'),
|
||||
)
|
||||
.filter(PgVectorEntry.collection == collection)
|
||||
.order_by(PgVectorEntry.embedding.cosine_distance(query_embedding))
|
||||
@@ -204,25 +185,17 @@ class PgVectorDatabase(VectorDatabase):
|
||||
for row in rows:
|
||||
ids.append(row.id)
|
||||
distances.append(float(row.distance))
|
||||
metadatas.append({
|
||||
"text": row.text or "",
|
||||
"file_id": row.file_id or "",
|
||||
"uuid": row.chunk_uuid or ""
|
||||
})
|
||||
|
||||
result_dict = {
|
||||
"ids": [ids],
|
||||
"distances": [distances],
|
||||
"metadatas": [metadatas]
|
||||
}
|
||||
|
||||
self.ap.logger.info(
|
||||
f"pgvector search in '{collection}' returned {len(ids)} results"
|
||||
metadatas.append(
|
||||
{'text': row.text or '', 'file_id': row.file_id or '', 'uuid': row.chunk_uuid or ''}
|
||||
)
|
||||
|
||||
result_dict = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
|
||||
|
||||
self.ap.logger.info(f"pgvector search in '{collection}' returned {len(ids)} results")
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Error searching pgvector: {e}")
|
||||
self.ap.logger.error(f'Error searching pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
@@ -239,8 +212,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(
|
||||
PgVectorEntry.collection == collection,
|
||||
PgVectorEntry.file_id == file_id
|
||||
PgVectorEntry.collection == collection, PgVectorEntry.file_id == file_id
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
@@ -250,7 +222,7 @@ class PgVectorDatabase(VectorDatabase):
|
||||
)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error deleting from pgvector: {e}")
|
||||
self.ap.logger.error(f'Error deleting from pgvector: {e}')
|
||||
raise
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
@@ -266,16 +238,14 @@ class PgVectorDatabase(VectorDatabase):
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(PgVectorEntry).where(
|
||||
PgVectorEntry.collection == collection
|
||||
)
|
||||
stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
self.ap.logger.info(f"Deleted pgvector collection '{collection}'")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
self.ap.logger.error(f"Error deleting pgvector collection: {e}")
|
||||
self.ap.logger.error(f'Error deleting pgvector collection: {e}')
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
|
||||
@@ -3,10 +3,8 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from langbot.pkg.core import app
|
||||
from langbot.pkg.entity.persistence import model as persistence_model
|
||||
from langbot.pkg.vector.vdb import VectorDatabase
|
||||
|
||||
try:
|
||||
@@ -87,14 +85,16 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
self._collections: Dict[str, Any] = {}
|
||||
self._collection_configs: Dict[str, HNSWConfiguration] = {}
|
||||
|
||||
self._escape_table = str.maketrans({
|
||||
self._escape_table = str.maketrans(
|
||||
{
|
||||
'\x00': '',
|
||||
'\\': '\\\\',
|
||||
'"': '\\"',
|
||||
'\n': '\\n',
|
||||
'\r': '\\r',
|
||||
'\t': '\\t',
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None) -> Any:
|
||||
"""Internal method to get or create a collection with proper configuration."""
|
||||
@@ -133,8 +133,10 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
def _clean_metadata(self, meta: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""SeekDB metadata doesn't support \\ and ", insert will error 3104"""
|
||||
return {
|
||||
k: v.translate(self._escape_table) if isinstance(v, str)
|
||||
else v if v is None or isinstance(v, (int, float, bool))
|
||||
k: v.translate(self._escape_table)
|
||||
if isinstance(v, str)
|
||||
else v
|
||||
if v is None or isinstance(v, (int, float, bool))
|
||||
else str(v)
|
||||
for k, v in meta.items()
|
||||
if v is not None
|
||||
@@ -145,11 +147,7 @@ class SeekDBVectorDatabase(VectorDatabase):
|
||||
return await self._get_or_create_collection_internal(collection)
|
||||
|
||||
async def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: List[str],
|
||||
embeddings_list: List[List[float]],
|
||||
metadatas: List[Dict[str, Any]]
|
||||
self, collection: str, ids: List[str], embeddings_list: List[List[float]], metadatas: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Add vector embeddings to the specified collection.
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
||||
Returns:
|
||||
Updated configuration dictionary
|
||||
"""
|
||||
|
||||
def convert_value(value: str, original_value: Any) -> Any:
|
||||
"""Convert string value to appropriate type based on original value
|
||||
|
||||
@@ -90,11 +91,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_simple_string_override(self):
|
||||
"""Test overriding a simple string value"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
}
|
||||
}
|
||||
cfg = {'api': {'port': 5300}}
|
||||
|
||||
# Set environment variable
|
||||
os.environ['API__PORT'] = '8080'
|
||||
@@ -108,12 +105,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_nested_key_override(self):
|
||||
"""Test overriding nested keys with __ delimiter"""
|
||||
cfg = {
|
||||
'concurrency': {
|
||||
'pipeline': 20,
|
||||
'session': 1
|
||||
}
|
||||
}
|
||||
cfg = {'concurrency': {'pipeline': 20, 'session': 1}}
|
||||
|
||||
os.environ['CONCURRENCY__PIPELINE'] = '50'
|
||||
|
||||
@@ -126,14 +118,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_deep_nested_override(self):
|
||||
"""Test overriding deeply nested keys"""
|
||||
cfg = {
|
||||
'system': {
|
||||
'jwt': {
|
||||
'expire': 604800,
|
||||
'secret': ''
|
||||
}
|
||||
}
|
||||
}
|
||||
cfg = {'system': {'jwt': {'expire': 604800, 'secret': ''}}}
|
||||
|
||||
os.environ['SYSTEM__JWT__EXPIRE'] = '86400'
|
||||
os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key'
|
||||
@@ -148,12 +133,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_underscore_in_key(self):
|
||||
"""Test keys with underscores like runtime_ws_url"""
|
||||
cfg = {
|
||||
'plugin': {
|
||||
'enable': True,
|
||||
'runtime_ws_url': 'ws://localhost:5400/control/ws'
|
||||
}
|
||||
}
|
||||
cfg = {'plugin': {'enable': True, 'runtime_ws_url': 'ws://localhost:5400/control/ws'}}
|
||||
|
||||
os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws'
|
||||
|
||||
@@ -165,12 +145,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_boolean_conversion(self):
|
||||
"""Test boolean value conversion"""
|
||||
cfg = {
|
||||
'plugin': {
|
||||
'enable': True,
|
||||
'enable_marketplace': False
|
||||
}
|
||||
}
|
||||
cfg = {'plugin': {'enable': True, 'enable_marketplace': False}}
|
||||
|
||||
os.environ['PLUGIN__ENABLE'] = 'false'
|
||||
os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true'
|
||||
@@ -185,14 +160,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_ignore_dict_type(self):
|
||||
"""Test that dict types are ignored"""
|
||||
cfg = {
|
||||
'database': {
|
||||
'use': 'sqlite',
|
||||
'sqlite': {
|
||||
'path': 'data/langbot.db'
|
||||
}
|
||||
}
|
||||
}
|
||||
cfg = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
|
||||
|
||||
# Try to override a dict value - should be ignored
|
||||
os.environ['DATABASE__SQLITE'] = 'new_value'
|
||||
@@ -207,13 +175,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_ignore_list_type(self):
|
||||
"""Test that list/array types are ignored"""
|
||||
cfg = {
|
||||
'admins': ['admin1', 'admin2'],
|
||||
'command': {
|
||||
'enable': True,
|
||||
'prefix': ['!', '!']
|
||||
}
|
||||
}
|
||||
cfg = {'admins': ['admin1', 'admin2'], 'command': {'enable': True, 'prefix': ['!', '!']}}
|
||||
|
||||
# Try to override list values - should be ignored
|
||||
os.environ['ADMINS'] = 'admin3'
|
||||
@@ -232,11 +194,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_lowercase_env_var_ignored(self):
|
||||
"""Test that lowercase environment variables are ignored"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
}
|
||||
}
|
||||
cfg = {'api': {'port': 5300}}
|
||||
|
||||
os.environ['api__port'] = '8080'
|
||||
|
||||
@@ -249,11 +207,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_no_double_underscore_ignored(self):
|
||||
"""Test that env vars without __ are ignored"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
}
|
||||
}
|
||||
cfg = {'api': {'port': 5300}}
|
||||
|
||||
os.environ['APIPORT'] = '8080'
|
||||
|
||||
@@ -266,11 +220,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_nonexistent_key_ignored(self):
|
||||
"""Test that env vars for non-existent keys are ignored"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
}
|
||||
}
|
||||
cfg = {'api': {'port': 5300}}
|
||||
|
||||
os.environ['API__NONEXISTENT'] = 'value'
|
||||
|
||||
@@ -283,11 +233,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_integer_conversion(self):
|
||||
"""Test integer value conversion"""
|
||||
cfg = {
|
||||
'concurrency': {
|
||||
'pipeline': 20
|
||||
}
|
||||
}
|
||||
cfg = {'concurrency': {'pipeline': 20}}
|
||||
|
||||
os.environ['CONCURRENCY__PIPELINE'] = '100'
|
||||
|
||||
@@ -300,18 +246,7 @@ class TestEnvOverrides:
|
||||
|
||||
def test_multiple_overrides(self):
|
||||
"""Test multiple environment variable overrides at once"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
},
|
||||
'concurrency': {
|
||||
'pipeline': 20,
|
||||
'session': 1
|
||||
},
|
||||
'plugin': {
|
||||
'enable': False
|
||||
}
|
||||
}
|
||||
cfg = {'api': {'port': 5300}, 'concurrency': {'pipeline': 20, 'session': 1}, 'plugin': {'enable': False}}
|
||||
|
||||
os.environ['API__PORT'] = '8080'
|
||||
os.environ['CONCURRENCY__PIPELINE'] = '50'
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Test plugin list filtering by component kinds."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
@@ -31,16 +30,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Tool',
|
||||
'metadata': {'name': 'tool1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -53,15 +43,8 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'KnowledgeRetriever',
|
||||
'metadata': {'name': 'retriever1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||
],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -73,16 +56,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Command',
|
||||
'metadata': {'name': 'cmd1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Command', 'metadata': {'name': 'cmd1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -94,16 +68,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'EventListener',
|
||||
'metadata': {'name': 'listener1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'EventListener', 'metadata': {'name': 'listener1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -116,23 +81,9 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'KnowledgeRetriever',
|
||||
'metadata': {'name': 'retriever2'}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Tool',
|
||||
'metadata': {'name': 'tool2'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever2'}}}},
|
||||
{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -187,16 +138,7 @@ async def test_plugin_list_filter_no_filter():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Tool',
|
||||
'metadata': {'name': 'tool1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -209,15 +151,8 @@ async def test_plugin_list_filter_no_filter():
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'KnowledgeRetriever',
|
||||
'metadata': {'name': 'retriever1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -267,15 +202,8 @@ async def test_plugin_list_filter_empty_result():
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'KnowledgeRetriever',
|
||||
'metadata': {'name': 'retriever1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -321,16 +249,7 @@ async def test_plugin_list_filter_plugin_without_components():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Tool',
|
||||
'metadata': {'name': 'tool1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -342,7 +261,7 @@ async def test_plugin_list_filter_plugin_without_components():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': []
|
||||
'components': [],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -121,7 +121,8 @@ const enUS = {
|
||||
webhookHint:
|
||||
'Webhooks allow LangBot to push person and group message events to external systems',
|
||||
actions: 'Actions',
|
||||
apiKeyCreatedMessage: 'Please copy this API key.',
|
||||
apiKeyCreatedMessage:
|
||||
'Please copy this API key, if the button is invalid, please copy manually.',
|
||||
none: 'None',
|
||||
},
|
||||
notFound: {
|
||||
|
||||
@@ -123,7 +123,8 @@ const jaJP = {
|
||||
webhookHint:
|
||||
'Webhook を使用すると、LangBot は個人メッセージとグループメッセージイベントを外部システムにプッシュできます',
|
||||
actions: 'アクション',
|
||||
apiKeyCreatedMessage: 'この API キーをコピーしてください。',
|
||||
apiKeyCreatedMessage:
|
||||
'この API キーをコピーしてください。もしボタンが無効な場合は手動でコピーしてください。',
|
||||
none: 'なし',
|
||||
},
|
||||
notFound: {
|
||||
|
||||
@@ -114,7 +114,7 @@ const zhHans = {
|
||||
noWebhooks: '暂无 Webhook',
|
||||
webhookHint: 'Webhook 允许 LangBot 将个人消息和群消息事件推送到外部系统',
|
||||
actions: '操作',
|
||||
apiKeyCreatedMessage: '请复制此 API 密钥。',
|
||||
apiKeyCreatedMessage: '请复制此 API 密钥,若按钮无效,请手动复制。',
|
||||
none: '无',
|
||||
},
|
||||
notFound: {
|
||||
|
||||
@@ -114,7 +114,7 @@ const zhHant = {
|
||||
noWebhooks: '暫無 Webhook',
|
||||
webhookHint: 'Webhook 允許 LangBot 將個人訊息和群組訊息事件推送到外部系統',
|
||||
actions: '操作',
|
||||
apiKeyCreatedMessage: '請複製此 API 金鑰。',
|
||||
apiKeyCreatedMessage: '請複製此 API 金鑰,若按鈕無效,請手動複製。',
|
||||
none: '無',
|
||||
},
|
||||
notFound: {
|
||||
|
||||
Reference in New Issue
Block a user