feat: add GitHub Actions workflow for linting with Ruff (#1929)

* feat: add GitHub Actions workflow for linting with Ruff

* refactor: rename lint job and add formatting step to Ruff workflow

* chore: run ruff format

* chore: rename Ruff lint job to 'Lint' and add frontend linting workflow
This commit is contained in:
Junyan Qin (Chin)
2026-01-23 13:43:12 +08:00
committed by GitHub
parent e60cb6ad0e
commit fc6e414be4
17 changed files with 327 additions and 450 deletions

60
.github/workflows/lint.yml vendored Normal file
View File

@@ -0,0 +1,60 @@
name: Lint
on:
push:
branches:
- main
- master
- dev
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
jobs:
ruff:
name: Ruff Lint & Format
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install dependencies
run: uv sync --dev
- name: Run ruff check
run: uv run ruff check src
- name: Run ruff format
run: uv run ruff format src --check
frontend:
name: Frontend Lint
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '25'
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: Install dependencies
working-directory: web
run: pnpm install
- name: Run lint
working-directory: web
run: pnpm lint

View File

@@ -85,7 +85,6 @@ class QQOfficialClient:
req: Quart Request 对象 req: Quart Request 对象
""" """
try: try:
body = await req.get_data() body = await req.get_data()
print(f'[QQ Official] Received request, body length: {len(body)}') print(f'[QQ Official] Received request, body length: {len(body)}')
@@ -96,7 +95,6 @@ class QQOfficialClient:
payload = json.loads(body) payload = json.loads(body)
if payload.get('op') == 13: if payload.get('op') == 13:
validation_data = payload.get('d') validation_data = payload.get('d')
if not validation_data: if not validation_data:
@@ -276,21 +274,21 @@ class QQOfficialClient:
seed = bot_secret seed = bot_secret
while len(seed) < target_size: while len(seed) < target_size:
seed *= 2 seed *= 2
return seed[:target_size].encode("utf-8") return seed[:target_size].encode('utf-8')
async def verify(self, validation_payload: dict): async def verify(self, validation_payload: dict):
seed = await self.repeat_seed(self.secret) seed = await self.repeat_seed(self.secret)
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed) private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
event_ts = validation_payload.get("event_ts", "") event_ts = validation_payload.get('event_ts', '')
plain_token = validation_payload.get("plain_token", "") plain_token = validation_payload.get('plain_token', '')
msg = event_ts + plain_token msg = event_ts + plain_token
# sign # sign
signature = private_key.sign(msg.encode()).hex() signature = private_key.sign(msg.encode()).hex()
response = { response = {
"plain_token": plain_token, 'plain_token': plain_token,
"signature": signature, 'signature': signature,
} }
return response return response

View File

@@ -36,7 +36,12 @@ class WecomBotEvent(dict):
""" """
用户名称 用户名称
""" """
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid return (
self.get('username', '')
or self.get('from', {}).get('alias', '')
or self.get('from', {}).get('name', '')
or self.userid
)
@property @property
def chatname(self) -> str: def chatname(self) -> str:
@@ -121,7 +126,7 @@ class WecomBotEvent(dict):
消息id 消息id
""" """
return self.get('msgid', '') return self.get('msgid', '')
@property @property
def ai_bot_id(self) -> str: def ai_bot_id(self) -> str:
""" """

View File

@@ -30,7 +30,6 @@ class WebhookRouterGroup(group.RouterGroup):
适配器返回的响应 适配器返回的响应
""" """
try: try:
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid) runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
if not runtime_bot: if not runtime_bot:
@@ -39,11 +38,9 @@ class WebhookRouterGroup(group.RouterGroup):
if not runtime_bot.enable: if not runtime_bot.enable:
return quart.jsonify({'error': 'Bot is disabled'}), 403 return quart.jsonify({'error': 'Bot is disabled'}), 403
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'): if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501 return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
response = await runtime_bot.adapter.handle_unified_webhook( response = await runtime_bot.adapter.handle_unified_webhook(
bot_uuid=bot_uuid, bot_uuid=bot_uuid,
path=path, path=path,

View File

@@ -59,7 +59,16 @@ class BotService:
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
# Webhook URL for unified webhook adapters (independent of bot running state) # Webhook URL for unified webhook adapters (independent of bot running state)
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE', 'lark']: if persistence_bot['adapter'] in [
'wecom',
'wecombot',
'officialaccount',
'qqofficial',
'slack',
'wecomcs',
'LINE',
'lark',
]:
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300') webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
webhook_url = f'/bots/{bot_uuid}' webhook_url = f'/bots/{bot_uuid}'
adapter_runtime_values['webhook_url'] = webhook_url adapter_runtime_values['webhook_url'] = webhook_url

View File

@@ -34,7 +34,6 @@ from .. import taskmgr
from ...telemetry import telemetry as telemetry_module from ...telemetry import telemetry as telemetry_module
@stage.stage_class('BuildAppStage') @stage.stage_class('BuildAppStage')
class BuildAppStage(stage.BootingStage): class BuildAppStage(stage.BootingStage):
"""Build LangBot application""" """Build LangBot application"""

View File

@@ -244,7 +244,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time)) lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
if message.message_type == 'text': if message.message_type == 'text':
element_list = [] element_list = []
@@ -310,7 +309,11 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
] ]
elif message.message_type == 'audio': elif message.message_type == 'audio':
message_content['content'] = [ message_content['content'] = [
{'tag': 'audio', 'file_key': message_content['file_key'], "duration": message_content.get('duration',0)} {
'tag': 'audio',
'file_key': message_content['file_key'],
'duration': message_content.get('duration', 0),
}
] ]
for ele in message_content['content']: for ele in message_content['content']:
@@ -367,12 +370,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
audio_bytes = response.file.read() audio_bytes = response.file.read()
audio_base64 = base64.b64encode(audio_bytes).decode() audio_base64 = base64.b64encode(audio_bytes).decode()
# Get content type from response headers # Get content type from response headers
content_type = response.raw.headers.get('content-type', 'audio/mpeg') content_type = response.raw.headers.get('content-type', 'audio/mpeg')
mime_main = content_type.split(';')[0].strip() mime_main = content_type.split(';')[0].strip()
ext = mimetypes.guess_extension(mime_main) or '.bin' ext = mimetypes.guess_extension(mime_main) or '.bin'
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
@@ -418,7 +418,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
file_bytes = response.file.read() file_bytes = response.file.read()
file_base64 = base64.b64encode(file_bytes).decode() file_base64 = base64.b64encode(file_bytes).decode()
file_format = response.raw.headers['content-type'] file_format = response.raw.headers['content-type']
file_size = len(file_bytes) file_size = len(file_bytes)
@@ -453,7 +452,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
) )
) )
return platform_message.MessageChain(lb_msg_list) return platform_message.MessageChain(lb_msg_list)

View File

@@ -18,52 +18,52 @@ import langbot_plugin.api.entities.builtin.platform.entities as platform_entitie
def split_string_by_bytes(text, limit=2048, encoding='utf-8'): def split_string_by_bytes(text, limit=2048, encoding='utf-8'):
""" """
Splits a string into a list of strings, where each part is at most 'limit' bytes. Splits a string into a list of strings, where each part is at most 'limit' bytes.
Args: Args:
text (str): The original string to split. text (str): The original string to split.
limit (int): The maximum byte size for each split part. limit (int): The maximum byte size for each split part.
encoding (str): The encoding to use (default is 'utf-8'). encoding (str): The encoding to use (default is 'utf-8').
Returns: Returns:
list: A list of split strings. list: A list of split strings.
""" """
# 1. Encode the entire string into bytes # 1. Encode the entire string into bytes
bytes_data = text.encode(encoding) bytes_data = text.encode(encoding)
total_len = len(bytes_data) total_len = len(bytes_data)
parts = [] parts = []
start = 0 start = 0
while start < total_len: while start < total_len:
# 2. Determine the end index for the current chunk # 2. Determine the end index for the current chunk
# It shouldn't exceed the total length # It shouldn't exceed the total length
end = min(start + limit, total_len) end = min(start + limit, total_len)
# 3. Slice the byte array # 3. Slice the byte array
chunk = bytes_data[start:end] chunk = bytes_data[start:end]
# 4. Attempt to decode the chunk # 4. Attempt to decode the chunk
# Use errors='ignore' to drop any partial bytes at the end of the chunk # Use errors='ignore' to drop any partial bytes at the end of the chunk
# (e.g., if a 3-byte character was cut after the 2nd byte) # (e.g., if a 3-byte character was cut after the 2nd byte)
part_str = chunk.decode(encoding, errors='ignore') part_str = chunk.decode(encoding, errors='ignore')
# 5. Calculate the actual byte length of the successfully decoded string # 5. Calculate the actual byte length of the successfully decoded string
# This tells us exactly where the valid character boundary ended # This tells us exactly where the valid character boundary ended
part_bytes = part_str.encode(encoding) part_bytes = part_str.encode(encoding)
part_len = len(part_bytes) part_len = len(part_bytes)
# Safety check: Prevent infinite loop if limit is too small (e.g., limit=1 for a Chinese char) # Safety check: Prevent infinite loop if limit is too small (e.g., limit=1 for a Chinese char)
if part_len == 0 and end < total_len: if part_len == 0 and end < total_len:
# Force advance by 1 byte to consume the un-decodable byte or raise error # Force advance by 1 byte to consume the un-decodable byte or raise error
# Here we just treat it as a part to avoid stuck loops, though it might be invalid # Here we just treat it as a part to avoid stuck loops, though it might be invalid
start += 1 start += 1
continue continue
parts.append(part_str) parts.append(part_str)
# 6. Move the start pointer by the actual length consumed # 6. Move the start pointer by the actual length consumed
start += part_len start += part_len
return parts return parts
@@ -75,13 +75,15 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
for msg in message_chain: for msg in message_chain:
if type(msg) is platform_message.Plain: if type(msg) is platform_message.Plain:
chunks = split_string_by_bytes(msg.text) chunks = split_string_by_bytes(msg.text)
content_list.extend([ content_list.extend(
{ [
'type': 'text', {
'content': chunk, 'type': 'text',
} 'content': chunk,
for chunk in chunks }
]) for chunk in chunks
]
)
elif type(msg) is platform_message.Image: elif type(msg) is platform_message.Image:
content_list.append( content_list.append(
{ {

View File

@@ -51,9 +51,10 @@ class SeekDBEmbedding(requester.ProviderAPIRequester):
await self.initialize() await self.initialize()
if self._embedding_function is None: if self._embedding_function is None:
raise RuntimeError("SeekDB embedding function initialization failed") raise RuntimeError('SeekDB embedding function initialization failed')
return self._embedding_function(input_text) return self._embedding_function(input_text)
except Exception as e: except Exception as e:
from .. import errors from .. import errors
raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}') raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}')

View File

@@ -218,10 +218,14 @@ class LocalAgentRunner(runner.RequestRunner):
parameters = {} parameters = {}
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query) func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query)
# Handle return value content # Handle return value content
tool_content = None tool_content = None
if isinstance(func_ret, list) and len(func_ret) > 0 and isinstance(func_ret[0], provider_message.ContentElement): if (
isinstance(func_ret, list)
and len(func_ret) > 0
and isinstance(func_ret[0], provider_message.ContentElement)
):
tool_content = func_ret tool_content = func_ret
else: else:
tool_content = json.dumps(func_ret, ensure_ascii=False) tool_content = json.dumps(func_ret, ensure_ascii=False)

View File

@@ -68,15 +68,16 @@ class N8nServiceAPIRunner(runner.RequestRunner):
return plain_text return plain_text
async def _process_stream_response(self, response: aiohttp.ClientResponse) -> typing.AsyncGenerator[ async def _process_stream_response(
provider_message.Message, None]: self, response: aiohttp.ClientResponse
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况""" """处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况"""
full_content = "" full_content = ''
chunk_idx = 0 chunk_idx = 0
is_final = False is_final = False
message_idx = 0 message_idx = 0
buffer = "" buffer = ''
decoder = json.JSONDecoder() decoder = json.JSONDecoder()
async for raw_chunk in response.content.iter_chunked(1024): async for raw_chunk in response.content.iter_chunked(1024):
@@ -129,7 +130,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
preview = chunk_str[:200] preview = chunk_str[:200]
except Exception: except Exception:
preview = '<unavailable>' preview = '<unavailable>'
self.ap.logger.warning(f"Failed to process chunk: {e}; chunk preview: {preview}") self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}')
# 流结束后,尝试解析残余 buffer # 流结束后,尝试解析残余 buffer
if buffer: if buffer:
@@ -151,7 +152,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
) )
except Exception as e: except Exception as e:
preview = buffer[:200] preview = buffer[:200]
self.ap.logger.warning(f"Failed to parse remaining buffer: {e}; buffer preview: {preview}") self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}')
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用n8n webhook""" """调用n8n webhook"""
@@ -165,7 +166,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
# 准备请求数据 # 准备请求数据
payload = { payload = {
# 基本消息内容 # 基本消息内容
'chatInput' :plain_text, # 考虑到之前用户直接用的message model这里添加新键 'chatInput': plain_text, # 考虑到之前用户直接用的message model这里添加新键
'message': plain_text, 'message': plain_text,
'user_message_text': plain_text, 'user_message_text': plain_text,
'conversation_id': query.session.using_conversation.uuid, 'conversation_id': query.session.using_conversation.uuid,
@@ -217,57 +218,49 @@ class N8nServiceAPIRunner(runner.RequestRunner):
# 调用webhook # 调用webhook
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
if is_stream: if is_stream:
# 流式请求 # 流式请求
async with session.post( async with session.post(
self.webhook_url, self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
json=payload, ) as response:
headers=headers, if response.status != 200:
auth=auth, error_text = await response.text()
timeout=self.timeout self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
) as response: raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
# 处理流式响应
async for chunk in self._process_stream_response(response):
yield chunk
else:
async with session.post(
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
) as response:
try:
async for chunk in self._process_stream_response(response):
output_content = chunk.content if chunk.is_final else ''
except:
# 非流式请求(保持原有逻辑)
if response.status != 200: if response.status != 200:
error_text = await response.text() error_text = await response.text()
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
# 处理流式响应 # 解析响应
async for chunk in self._process_stream_response(response): response_data = await response.json()
yield chunk self.ap.logger.debug(f'n8n webhook response: {response_data}')
else:
async with session.post(
self.webhook_url,
json=payload,
headers=headers,
auth=auth,
timeout=self.timeout
) as response:
try:
async for chunk in self._process_stream_response(response):
output_content = chunk.content if chunk.is_final else ''
except:
# 非流式请求(保持原有逻辑)
if response.status != 200:
error_text = await response.text()
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
# 解析响应 # 从响应中提取输出
response_data = await response.json() if self.output_key in response_data:
self.ap.logger.debug(f'n8n webhook response: {response_data}') output_content = response_data[self.output_key]
else:
# 如果没有指定的输出键,则使用整个响应
output_content = json.dumps(response_data, ensure_ascii=False)
# 从响应中提取输出 # 返回消息
if self.output_key in response_data: yield provider_message.Message(
output_content = response_data[self.output_key] role='assistant',
else: content=output_content,
# 如果没有指定的输出键,则使用整个响应 )
output_content = json.dumps(response_data, ensure_ascii=False)
# 返回消息
yield provider_message.Message(
role='assistant',
content=output_content,
)
except Exception as e: except Exception as e:
self.ap.logger.error(f'n8n webhook call exception: {str(e)}') self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
raise N8nAPIError(f'n8n webhook call exception: {str(e)}') raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
@@ -275,4 +268,4 @@ class N8nServiceAPIRunner(runner.RequestRunner):
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""运行请求""" """运行请求"""
async for msg in self._call_webhook(query): async for msg in self._call_webhook(query):
yield msg yield msg

View File

@@ -194,7 +194,7 @@ class RuntimeMCPSession:
async def func(*, _tool=tool, **kwargs): async def func(*, _tool=tool, **kwargs):
if not self.session: if not self.session:
raise Exception("MCP session is not connected") raise Exception('MCP session is not connected')
result = await self.session.call_tool(_tool.name, kwargs) result = await self.session.call_tool(_tool.name, kwargs)
if result.isError: if result.isError:
@@ -202,8 +202,8 @@ class RuntimeMCPSession:
for content in result.content: for content in result.content:
if content.type == 'text': if content.type == 'text':
error_texts.append(content.text) error_texts.append(content.text)
raise Exception("\n".join(error_texts) if error_texts else "Unknown error from MCP tool") raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool')
result_contents: list[provider_message.ContentElement] = [] result_contents: list[provider_message.ContentElement] = []
for content in result.content: for content in result.content:
if content.type == 'text': if content.type == 'text':
@@ -213,7 +213,7 @@ class RuntimeMCPSession:
elif content.type == 'resource': elif content.type == 'resource':
# TODO: Handle resource content # TODO: Handle resource content
pass pass
return result_contents return result_contents
func.__name__ = tool.name func.__name__ = tool.name
@@ -221,8 +221,8 @@ class RuntimeMCPSession:
self.functions.append( self.functions.append(
resource_tool.LLMTool( resource_tool.LLMTool(
name=tool.name, name=tool.name,
human_desc=tool.description or "", human_desc=tool.description or '',
description=tool.description or "", description=tool.description or '',
parameters=tool.inputSchema, parameters=tool.inputSchema,
func=func, func=func,
) )
@@ -338,13 +338,10 @@ class MCPLoader(loader.ToolLoader):
""" """
uuid_ = server_config.get('uuid') uuid_ = server_config.get('uuid')
if not uuid_: if not uuid_:
self.ap.logger.warning( self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.')
'Server UUID is None for MCP server, maybe testing in the config page.'
)
uuid_ = str(uuid_module.uuid4()) uuid_ = str(uuid_module.uuid4())
server_config['uuid'] = uuid_ server_config['uuid'] = uuid_
name = server_config['name'] name = server_config['name']
uuid = server_config['uuid'] uuid = server_config['uuid']
mode = server_config['mode'] mode = server_config['mode']

View File

@@ -35,9 +35,9 @@ class Embedder(BaseService):
# get embeddings (batch size limit: 64 for OpenAI) # get embeddings (batch size limit: 64 for OpenAI)
MAX_BATCH_SIZE = 64 MAX_BATCH_SIZE = 64
embeddings_list: list[list[float]] = [] embeddings_list: list[list[float]] = []
for i in range(0, len(chunks), MAX_BATCH_SIZE): for i in range(0, len(chunks), MAX_BATCH_SIZE):
batch = chunks[i:i + MAX_BATCH_SIZE] batch = chunks[i : i + MAX_BATCH_SIZE]
batch_embeddings = await embedding_model.provider.requester.invoke_embedding( batch_embeddings = await embedding_model.provider.requester.invoke_embedding(
model=embedding_model, model=embedding_model,
input_text=batch, input_text=batch,

View File

@@ -55,12 +55,7 @@ class VectorDBManager:
user = pgvector_config.get('user', 'postgres') user = pgvector_config.get('user', 'postgres')
password = pgvector_config.get('password', 'postgres') password = pgvector_config.get('password', 'postgres')
self.vector_db = PgVectorDatabase( self.vector_db = PgVectorDatabase(
self.ap, self.ap, host=host, port=port, database=database, user=user, password=password
host=host,
port=port,
database=database,
user=user,
password=password
) )
self.ap.logger.info('Initialized pgvector database backend.') self.ap.logger.info('Initialized pgvector database backend.')

View File

@@ -10,7 +10,7 @@ from langbot.pkg.core import app
class MilvusVectorDatabase(VectorDatabase): class MilvusVectorDatabase(VectorDatabase):
"""Milvus vector database implementation""" """Milvus vector database implementation"""
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None): def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None):
"""Initialize Milvus vector database """Initialize Milvus vector database
Args: Args:
@@ -34,32 +34,32 @@ class MilvusVectorDatabase(VectorDatabase):
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name) self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
else: else:
self.client = MilvusClient(uri=self.uri, db_name=self.db_name) self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
self.ap.logger.info(f"Connected to Milvus at {self.uri}") self.ap.logger.info(f'Connected to Milvus at {self.uri}')
except Exception as e: except Exception as e:
self.ap.logger.error(f"Failed to connect to Milvus: {e}") self.ap.logger.error(f'Failed to connect to Milvus: {e}')
raise raise
@staticmethod @staticmethod
def _normalize_collection_name(collection: str) -> str: def _normalize_collection_name(collection: str) -> str:
"""Normalize collection name to comply with Milvus naming requirements. """Normalize collection name to comply with Milvus naming requirements.
Milvus requirements: Milvus requirements:
- First character must be an underscore or letter - First character must be an underscore or letter
- Can only contain numbers, letters and underscores - Can only contain numbers, letters and underscores
Args: Args:
collection: Original collection name (e.g., UUID with hyphens) collection: Original collection name (e.g., UUID with hyphens)
Returns: Returns:
Normalized collection name that complies with Milvus requirements Normalized collection name that complies with Milvus requirements
""" """
# Replace hyphens with underscores # Replace hyphens with underscores
normalized = collection.replace('-', '_') normalized = collection.replace('-', '_')
# If first character is not a letter or underscore, prepend 'kb_' # If first character is not a letter or underscore, prepend 'kb_'
if normalized and not (normalized[0].isalpha() or normalized[0] == '_'): if normalized and not (normalized[0].isalpha() or normalized[0] == '_'):
normalized = 'kb_' + normalized normalized = 'kb_' + normalized
return normalized return normalized
async def _ensure_vector_index(self, collection: str) -> None: async def _ensure_vector_index(self, collection: str) -> None:
@@ -70,15 +70,11 @@ class MilvusVectorDatabase(VectorDatabase):
""" """
index_params = IndexParams() index_params = IndexParams()
index_params.add_index( index_params.add_index(
field_name="vector", field_name='vector',
index_type="AUTOINDEX", index_type='AUTOINDEX',
metric_type="COSINE", metric_type='COSINE',
)
await asyncio.to_thread(
self.client.create_index,
collection_name=collection,
index_params=index_params
) )
await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params)
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None): async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
"""Internal method to get or create a Milvus collection with proper configuration. """Internal method to get or create a Milvus collection with proper configuration.
@@ -89,14 +85,12 @@ class MilvusVectorDatabase(VectorDatabase):
""" """
# Normalize collection name for Milvus compatibility # Normalize collection name for Milvus compatibility
collection = self._normalize_collection_name(collection) collection = self._normalize_collection_name(collection)
if collection in self._collections: if collection in self._collections:
return collection return collection
# Check if collection exists # Check if collection exists
has_collection = await asyncio.to_thread( has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
self.client.has_collection, collection_name=collection
)
if not has_collection: if not has_collection:
# Default dimension if not specified (for backward compatibility) # Default dimension if not specified (for backward compatibility)
@@ -104,24 +98,26 @@ class MilvusVectorDatabase(VectorDatabase):
vector_size = 1536 vector_size = 1536
fields = [ fields = [
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255), FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size), FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255), FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255),
FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255), FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255),
] ]
schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors") schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors')
await asyncio.to_thread( await asyncio.to_thread(
self.client.create_collection, self.client.create_collection,
collection_name=collection, collection_name=collection,
schema=schema, schema=schema,
metric_type="COSINE", metric_type='COSINE',
) )
await self._ensure_vector_index(collection) await self._ensure_vector_index(collection)
self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX") self.ap.logger.info(
f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX"
)
else: else:
# Ensure index exists for existing collection # Ensure index exists for existing collection
await self._ensure_index_if_missing(collection) await self._ensure_index_if_missing(collection)
@@ -137,11 +133,8 @@ class MilvusVectorDatabase(VectorDatabase):
collection: Normalized collection name collection: Normalized collection name
""" """
try: try:
indexes = await asyncio.to_thread( indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection)
self.client.list_indexes, if 'vector' not in indexes:
collection_name=collection
)
if "vector" not in indexes:
await self._ensure_vector_index(collection) await self._ensure_vector_index(collection)
self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'") self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
except Exception as e: except Exception as e:
@@ -172,7 +165,7 @@ class MilvusVectorDatabase(VectorDatabase):
metadatas: List of metadata dictionaries for each vector metadatas: List of metadata dictionaries for each vector
""" """
collection = self._normalize_collection_name(collection) collection = self._normalize_collection_name(collection)
if not embeddings_list: if not embeddings_list:
return return
@@ -184,39 +177,30 @@ class MilvusVectorDatabase(VectorDatabase):
data = [] data = []
for i, vector_id in enumerate(ids): for i, vector_id in enumerate(ids):
entry = { entry = {
"id": vector_id, 'id': vector_id,
"vector": embeddings_list[i], 'vector': embeddings_list[i],
} }
# Add metadata fields # Add metadata fields
if metadatas and i < len(metadatas): if metadatas and i < len(metadatas):
metadata = metadatas[i] metadata = metadatas[i]
# Add common metadata fields # Add common metadata fields
if "text" in metadata: if 'text' in metadata:
entry["text"] = metadata["text"] entry['text'] = metadata['text']
if "file_id" in metadata: if 'file_id' in metadata:
entry["file_id"] = metadata["file_id"] entry['file_id'] = metadata['file_id']
if "uuid" in metadata: if 'uuid' in metadata:
entry["chunk_uuid"] = metadata["uuid"] entry['chunk_uuid'] = metadata['uuid']
data.append(entry) data.append(entry)
# Insert data into Milvus # Insert data into Milvus
await asyncio.to_thread( await asyncio.to_thread(self.client.insert, collection_name=collection, data=data)
self.client.insert,
collection_name=collection,
data=data
)
# Load collection for searching (Milvus requires this) # Load collection for searching (Milvus requires this)
await asyncio.to_thread( await asyncio.to_thread(self.client.load_collection, collection_name=collection)
self.client.load_collection,
collection_name=collection
)
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'") self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
async def search( async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
self, collection: str, query_embedding: list[float], k: int = 5
) -> Dict[str, Any]:
"""Search for similar vectors in Milvus collection """Search for similar vectors in Milvus collection
Args: Args:
@@ -231,10 +215,7 @@ class MilvusVectorDatabase(VectorDatabase):
await self.get_or_create_collection(collection) await self.get_or_create_collection(collection)
# Perform search # Perform search
search_params = { search_params = {'metric_type': 'COSINE', 'params': {}}
"metric_type": "COSINE",
"params": {}
}
results = await asyncio.to_thread( results = await asyncio.to_thread(
self.client.search, self.client.search,
@@ -242,7 +223,7 @@ class MilvusVectorDatabase(VectorDatabase):
data=[query_embedding], data=[query_embedding],
limit=k, limit=k,
search_params=search_params, search_params=search_params,
output_fields=["text", "file_id", "chunk_uuid"] output_fields=['text', 'file_id', 'chunk_uuid'],
) )
# Convert results to Chroma-compatible format # Convert results to Chroma-compatible format
@@ -253,30 +234,24 @@ class MilvusVectorDatabase(VectorDatabase):
if results and len(results) > 0: if results and len(results) > 0:
for hit in results[0]: for hit in results[0]:
ids.append(hit.get("id", "")) ids.append(hit.get('id', ''))
distances.append(hit.get("distance", 0.0)) distances.append(hit.get('distance', 0.0))
# Build metadata from entity fields # Build metadata from entity fields
entity = hit.get("entity", {}) entity = hit.get('entity', {})
metadata = {} metadata = {}
if "text" in entity: if 'text' in entity:
metadata["text"] = entity["text"] metadata['text'] = entity['text']
if "file_id" in entity: if 'file_id' in entity:
metadata["file_id"] = entity["file_id"] metadata['file_id'] = entity['file_id']
if "chunk_uuid" in entity: if 'chunk_uuid' in entity:
metadata["uuid"] = entity["chunk_uuid"] metadata['uuid'] = entity['chunk_uuid']
metadatas.append(metadata) metadatas.append(metadata)
# Return in Chroma-compatible format (nested lists) # Return in Chroma-compatible format (nested lists)
result = { result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
"ids": [ids],
"distances": [distances],
"metadatas": [metadatas]
}
self.ap.logger.info( self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results")
f"Milvus search in '{collection}' returned {len(ids)} results"
)
return result return result
async def delete_by_file_id(self, collection: str, file_id: str) -> None: async def delete_by_file_id(self, collection: str, file_id: str) -> None:
@@ -290,14 +265,8 @@ class MilvusVectorDatabase(VectorDatabase):
await self.get_or_create_collection(collection) await self.get_or_create_collection(collection)
# Delete entities matching the file_id # Delete entities matching the file_id
await asyncio.to_thread( await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
self.client.delete, self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
collection_name=collection,
filter=f'file_id == "{file_id}"'
)
self.ap.logger.info(
f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}"
)
async def delete_collection(self, collection: str): async def delete_collection(self, collection: str):
"""Delete a Milvus collection """Delete a Milvus collection
@@ -306,18 +275,14 @@ class MilvusVectorDatabase(VectorDatabase):
collection: Collection name to delete collection: Collection name to delete
""" """
collection = self._normalize_collection_name(collection) collection = self._normalize_collection_name(collection)
self._collections.discard(collection) self._collections.discard(collection)
# Check if collection exists before attempting deletion # Check if collection exists before attempting deletion
has_collection = await asyncio.to_thread( has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
self.client.has_collection, collection_name=collection
)
if has_collection: if has_collection:
await asyncio.to_thread( await asyncio.to_thread(self.client.drop_collection, collection_name=collection)
self.client.drop_collection, collection_name=collection
)
self.ap.logger.info(f"Deleted Milvus collection '{collection}'") self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
else: else:
self.ap.logger.warning(f"Milvus collection '{collection}' not found") self.ap.logger.warning(f"Milvus collection '{collection}' not found")

View File

@@ -9,27 +9,28 @@ from typing import Any
def _apply_env_overrides_to_config(cfg: dict) -> dict: def _apply_env_overrides_to_config(cfg: dict) -> dict:
"""Apply environment variable overrides to data/config.yaml """Apply environment variable overrides to data/config.yaml
Environment variables should be uppercase and use __ (double underscore) Environment variables should be uppercase and use __ (double underscore)
to represent nested keys. For example: to represent nested keys. For example:
- CONCURRENCY__PIPELINE overrides concurrency.pipeline - CONCURRENCY__PIPELINE overrides concurrency.pipeline
- PLUGIN__RUNTIME_WS_URL overrides plugin.runtime_ws_url - PLUGIN__RUNTIME_WS_URL overrides plugin.runtime_ws_url
Arrays and dict types are ignored. Arrays and dict types are ignored.
Args: Args:
cfg: Configuration dictionary cfg: Configuration dictionary
Returns: Returns:
Updated configuration dictionary Updated configuration dictionary
""" """
def convert_value(value: str, original_value: Any) -> Any: def convert_value(value: str, original_value: Any) -> Any:
"""Convert string value to appropriate type based on original value """Convert string value to appropriate type based on original value
Args: Args:
value: String value from environment variable value: String value from environment variable
original_value: Original value to infer type from original_value: Original value to infer type from
Returns: Returns:
Converted value (falls back to string if conversion fails) Converted value (falls back to string if conversion fails)
""" """
@@ -49,7 +50,7 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
return value return value
else: else:
return value return value
# Process environment variables # Process environment variables
for env_key, env_value in os.environ.items(): for env_key, env_value in os.environ.items():
# Check if the environment variable is uppercase and contains __ # Check if the environment variable is uppercase and contains __
@@ -57,18 +58,18 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
continue continue
if '__' not in env_key: if '__' not in env_key:
continue continue
# Convert environment variable name to config path # Convert environment variable name to config path
# e.g., CONCURRENCY__PIPELINE -> ['concurrency', 'pipeline'] # e.g., CONCURRENCY__PIPELINE -> ['concurrency', 'pipeline']
keys = [key.lower() for key in env_key.split('__')] keys = [key.lower() for key in env_key.split('__')]
# Navigate to the target value and validate the path # Navigate to the target value and validate the path
current = cfg current = cfg
for i, key in enumerate(keys): for i, key in enumerate(keys):
if not isinstance(current, dict) or key not in current: if not isinstance(current, dict) or key not in current:
break break
if i == len(keys) - 1: if i == len(keys) - 1:
# At the final key - check if it's a scalar value # At the final key - check if it's a scalar value
if isinstance(current[key], (dict, list)): if isinstance(current[key], (dict, list)):
@@ -81,248 +82,182 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
else: else:
# Navigate deeper # Navigate deeper
current = current[key] current = current[key]
return cfg return cfg
class TestEnvOverrides: class TestEnvOverrides:
"""Test environment variable override functionality""" """Test environment variable override functionality"""
def test_simple_string_override(self): def test_simple_string_override(self):
"""Test overriding a simple string value""" """Test overriding a simple string value"""
cfg = { cfg = {'api': {'port': 5300}}
'api': {
'port': 5300
}
}
# Set environment variable # Set environment variable
os.environ['API__PORT'] = '8080' os.environ['API__PORT'] = '8080'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
assert result['api']['port'] == 8080 assert result['api']['port'] == 8080
# Cleanup # Cleanup
del os.environ['API__PORT'] del os.environ['API__PORT']
def test_nested_key_override(self): def test_nested_key_override(self):
"""Test overriding nested keys with __ delimiter""" """Test overriding nested keys with __ delimiter"""
cfg = { cfg = {'concurrency': {'pipeline': 20, 'session': 1}}
'concurrency': {
'pipeline': 20,
'session': 1
}
}
os.environ['CONCURRENCY__PIPELINE'] = '50' os.environ['CONCURRENCY__PIPELINE'] = '50'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
assert result['concurrency']['pipeline'] == 50 assert result['concurrency']['pipeline'] == 50
assert result['concurrency']['session'] == 1 # Unchanged assert result['concurrency']['session'] == 1 # Unchanged
del os.environ['CONCURRENCY__PIPELINE'] del os.environ['CONCURRENCY__PIPELINE']
def test_deep_nested_override(self): def test_deep_nested_override(self):
"""Test overriding deeply nested keys""" """Test overriding deeply nested keys"""
cfg = { cfg = {'system': {'jwt': {'expire': 604800, 'secret': ''}}}
'system': {
'jwt': {
'expire': 604800,
'secret': ''
}
}
}
os.environ['SYSTEM__JWT__EXPIRE'] = '86400' os.environ['SYSTEM__JWT__EXPIRE'] = '86400'
os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key' os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
assert result['system']['jwt']['expire'] == 86400 assert result['system']['jwt']['expire'] == 86400
assert result['system']['jwt']['secret'] == 'my_secret_key' assert result['system']['jwt']['secret'] == 'my_secret_key'
del os.environ['SYSTEM__JWT__EXPIRE'] del os.environ['SYSTEM__JWT__EXPIRE']
del os.environ['SYSTEM__JWT__SECRET'] del os.environ['SYSTEM__JWT__SECRET']
def test_underscore_in_key(self): def test_underscore_in_key(self):
"""Test keys with underscores like runtime_ws_url""" """Test keys with underscores like runtime_ws_url"""
cfg = { cfg = {'plugin': {'enable': True, 'runtime_ws_url': 'ws://localhost:5400/control/ws'}}
'plugin': {
'enable': True,
'runtime_ws_url': 'ws://localhost:5400/control/ws'
}
}
os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws' os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
assert result['plugin']['runtime_ws_url'] == 'ws://newhost:6000/ws' assert result['plugin']['runtime_ws_url'] == 'ws://newhost:6000/ws'
del os.environ['PLUGIN__RUNTIME_WS_URL'] del os.environ['PLUGIN__RUNTIME_WS_URL']
def test_boolean_conversion(self): def test_boolean_conversion(self):
"""Test boolean value conversion""" """Test boolean value conversion"""
cfg = { cfg = {'plugin': {'enable': True, 'enable_marketplace': False}}
'plugin': {
'enable': True,
'enable_marketplace': False
}
}
os.environ['PLUGIN__ENABLE'] = 'false' os.environ['PLUGIN__ENABLE'] = 'false'
os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true' os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
assert result['plugin']['enable'] is False assert result['plugin']['enable'] is False
assert result['plugin']['enable_marketplace'] is True assert result['plugin']['enable_marketplace'] is True
del os.environ['PLUGIN__ENABLE'] del os.environ['PLUGIN__ENABLE']
del os.environ['PLUGIN__ENABLE_MARKETPLACE'] del os.environ['PLUGIN__ENABLE_MARKETPLACE']
def test_ignore_dict_type(self): def test_ignore_dict_type(self):
"""Test that dict types are ignored""" """Test that dict types are ignored"""
cfg = { cfg = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
'database': {
'use': 'sqlite',
'sqlite': {
'path': 'data/langbot.db'
}
}
}
# Try to override a dict value - should be ignored # Try to override a dict value - should be ignored
os.environ['DATABASE__SQLITE'] = 'new_value' os.environ['DATABASE__SQLITE'] = 'new_value'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
# Should remain a dict, not overridden # Should remain a dict, not overridden
assert isinstance(result['database']['sqlite'], dict) assert isinstance(result['database']['sqlite'], dict)
assert result['database']['sqlite']['path'] == 'data/langbot.db' assert result['database']['sqlite']['path'] == 'data/langbot.db'
del os.environ['DATABASE__SQLITE'] del os.environ['DATABASE__SQLITE']
def test_ignore_list_type(self): def test_ignore_list_type(self):
"""Test that list/array types are ignored""" """Test that list/array types are ignored"""
cfg = { cfg = {'admins': ['admin1', 'admin2'], 'command': {'enable': True, 'prefix': ['!', '']}}
'admins': ['admin1', 'admin2'],
'command': {
'enable': True,
'prefix': ['!', '']
}
}
# Try to override list values - should be ignored # Try to override list values - should be ignored
os.environ['ADMINS'] = 'admin3' os.environ['ADMINS'] = 'admin3'
os.environ['COMMAND__PREFIX'] = '?' os.environ['COMMAND__PREFIX'] = '?'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
# Should remain lists, not overridden # Should remain lists, not overridden
assert isinstance(result['admins'], list) assert isinstance(result['admins'], list)
assert result['admins'] == ['admin1', 'admin2'] assert result['admins'] == ['admin1', 'admin2']
assert isinstance(result['command']['prefix'], list) assert isinstance(result['command']['prefix'], list)
assert result['command']['prefix'] == ['!', ''] assert result['command']['prefix'] == ['!', '']
del os.environ['ADMINS'] del os.environ['ADMINS']
del os.environ['COMMAND__PREFIX'] del os.environ['COMMAND__PREFIX']
def test_lowercase_env_var_ignored(self): def test_lowercase_env_var_ignored(self):
"""Test that lowercase environment variables are ignored""" """Test that lowercase environment variables are ignored"""
cfg = { cfg = {'api': {'port': 5300}}
'api': {
'port': 5300
}
}
os.environ['api__port'] = '8080' os.environ['api__port'] = '8080'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
# Should not be overridden # Should not be overridden
assert result['api']['port'] == 5300 assert result['api']['port'] == 5300
del os.environ['api__port'] del os.environ['api__port']
def test_no_double_underscore_ignored(self): def test_no_double_underscore_ignored(self):
"""Test that env vars without __ are ignored""" """Test that env vars without __ are ignored"""
cfg = { cfg = {'api': {'port': 5300}}
'api': {
'port': 5300
}
}
os.environ['APIPORT'] = '8080' os.environ['APIPORT'] = '8080'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
# Should not be overridden # Should not be overridden
assert result['api']['port'] == 5300 assert result['api']['port'] == 5300
del os.environ['APIPORT'] del os.environ['APIPORT']
def test_nonexistent_key_ignored(self): def test_nonexistent_key_ignored(self):
"""Test that env vars for non-existent keys are ignored""" """Test that env vars for non-existent keys are ignored"""
cfg = { cfg = {'api': {'port': 5300}}
'api': {
'port': 5300
}
}
os.environ['API__NONEXISTENT'] = 'value' os.environ['API__NONEXISTENT'] = 'value'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
# Should not create new key # Should not create new key
assert 'nonexistent' not in result['api'] assert 'nonexistent' not in result['api']
del os.environ['API__NONEXISTENT'] del os.environ['API__NONEXISTENT']
def test_integer_conversion(self): def test_integer_conversion(self):
"""Test integer value conversion""" """Test integer value conversion"""
cfg = { cfg = {'concurrency': {'pipeline': 20}}
'concurrency': {
'pipeline': 20
}
}
os.environ['CONCURRENCY__PIPELINE'] = '100' os.environ['CONCURRENCY__PIPELINE'] = '100'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
assert result['concurrency']['pipeline'] == 100 assert result['concurrency']['pipeline'] == 100
assert isinstance(result['concurrency']['pipeline'], int) assert isinstance(result['concurrency']['pipeline'], int)
del os.environ['CONCURRENCY__PIPELINE'] del os.environ['CONCURRENCY__PIPELINE']
def test_multiple_overrides(self): def test_multiple_overrides(self):
"""Test multiple environment variable overrides at once""" """Test multiple environment variable overrides at once"""
cfg = { cfg = {'api': {'port': 5300}, 'concurrency': {'pipeline': 20, 'session': 1}, 'plugin': {'enable': False}}
'api': {
'port': 5300
},
'concurrency': {
'pipeline': 20,
'session': 1
},
'plugin': {
'enable': False
}
}
os.environ['API__PORT'] = '8080' os.environ['API__PORT'] = '8080'
os.environ['CONCURRENCY__PIPELINE'] = '50' os.environ['CONCURRENCY__PIPELINE'] = '50'
os.environ['PLUGIN__ENABLE'] = 'true' os.environ['PLUGIN__ENABLE'] = 'true'
result = _apply_env_overrides_to_config(cfg) result = _apply_env_overrides_to_config(cfg)
assert result['api']['port'] == 8080 assert result['api']['port'] == 8080
assert result['concurrency']['pipeline'] == 50 assert result['concurrency']['pipeline'] == 50
assert result['plugin']['enable'] is True assert result['plugin']['enable'] is True
del os.environ['API__PORT'] del os.environ['API__PORT']
del os.environ['CONCURRENCY__PIPELINE'] del os.environ['CONCURRENCY__PIPELINE']
del os.environ['PLUGIN__ENABLE'] del os.environ['PLUGIN__ENABLE']

View File

@@ -1,6 +1,5 @@
"""Test plugin list filtering by component kinds.""" """Test plugin list filtering by component kinds."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@@ -31,16 +30,7 @@ async def test_plugin_list_filter_by_component_kinds():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'Tool',
'metadata': {'name': 'tool1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -53,15 +43,8 @@ async def test_plugin_list_filter_by_component_kinds():
} }
}, },
'components': [ 'components': [
{ {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
'manifest': { ],
'manifest': {
'kind': 'KnowledgeRetriever',
'metadata': {'name': 'retriever1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -73,16 +56,7 @@ async def test_plugin_list_filter_by_component_kinds():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'Command', 'metadata': {'name': 'cmd1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'Command',
'metadata': {'name': 'cmd1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -94,16 +68,7 @@ async def test_plugin_list_filter_by_component_kinds():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'EventListener', 'metadata': {'name': 'listener1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'EventListener',
'metadata': {'name': 'listener1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -116,23 +81,9 @@ async def test_plugin_list_filter_by_component_kinds():
} }
}, },
'components': [ 'components': [
{ {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever2'}}}},
'manifest': { {'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}},
'manifest': { ],
'kind': 'KnowledgeRetriever',
'metadata': {'name': 'retriever2'}
}
}
},
{
'manifest': {
'manifest': {
'kind': 'Tool',
'metadata': {'name': 'tool2'}
}
}
}
]
}, },
] ]
@@ -187,16 +138,7 @@ async def test_plugin_list_filter_no_filter():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'Tool',
'metadata': {'name': 'tool1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -209,15 +151,8 @@ async def test_plugin_list_filter_no_filter():
} }
}, },
'components': [ 'components': [
{ {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
'manifest': { ],
'manifest': {
'kind': 'KnowledgeRetriever',
'metadata': {'name': 'retriever1'}
}
}
}
]
}, },
] ]
@@ -267,15 +202,8 @@ async def test_plugin_list_filter_empty_result():
} }
}, },
'components': [ 'components': [
{ {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
'manifest': { ],
'manifest': {
'kind': 'KnowledgeRetriever',
'metadata': {'name': 'retriever1'}
}
}
}
]
}, },
] ]
@@ -321,16 +249,7 @@ async def test_plugin_list_filter_plugin_without_components():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'Tool',
'metadata': {'name': 'tool1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -342,7 +261,7 @@ async def test_plugin_list_filter_plugin_without_components():
} }
} }
}, },
'components': [] 'components': [],
}, },
] ]