mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: add GitHub Actions workflow for linting with Ruff (#1929)
* feat: add GitHub Actions workflow for linting with Ruff
* refactor: rename lint job and add formatting step to Ruff workflow
* chore: run ruff format
* chore: rename Ruff lint job to 'Lint' and add frontend linting workflow
This commit is contained in:
committed by
GitHub
parent
e60cb6ad0e
commit
fc6e414be4
60
.github/workflows/lint.yml
vendored
Normal file
60
.github/workflows/lint.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
# Lint workflow: checks the Python sources with Ruff (lint + format check)
# and runs the frontend linter for the web app under ./web.
name: Lint

on:
  push:
    branches:
      - main
      - master
      - dev
  pull_request:
    types: [opened, synchronize, reopened, ready_for_review]

# Both jobs only read the repository, so restrict the token to read access.
permissions:
  contents: read

jobs:
  ruff:
    name: Ruff Lint & Format
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Install dependencies
        run: uv sync --dev

      - name: Run ruff check
        run: uv run ruff check src

      # --check reports unformatted files through the exit status instead of
      # rewriting them, so this step fails when formatting is needed.
      - name: Run ruff format
        run: uv run ruff format src --check

  frontend:
    name: Frontend Lint
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '25'

      - name: Install pnpm
        uses: pnpm/action-setup@v4
        with:
          version: 9

      - name: Install dependencies
        working-directory: web
        run: pnpm install

      - name: Run lint
        working-directory: web
        run: pnpm lint
|
||||
@@ -85,7 +85,6 @@ class QQOfficialClient:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
|
||||
body = await req.get_data()
|
||||
|
||||
print(f'[QQ Official] Received request, body length: {len(body)}')
|
||||
@@ -96,7 +95,6 @@ class QQOfficialClient:
|
||||
|
||||
payload = json.loads(body)
|
||||
|
||||
|
||||
if payload.get('op') == 13:
|
||||
validation_data = payload.get('d')
|
||||
if not validation_data:
|
||||
@@ -276,21 +274,21 @@ class QQOfficialClient:
|
||||
seed = bot_secret
|
||||
while len(seed) < target_size:
|
||||
seed *= 2
|
||||
return seed[:target_size].encode("utf-8")
|
||||
return seed[:target_size].encode('utf-8')
|
||||
|
||||
async def verify(self, validation_payload: dict):
|
||||
seed = await self.repeat_seed(self.secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
|
||||
event_ts = validation_payload.get("event_ts", "")
|
||||
plain_token = validation_payload.get("plain_token", "")
|
||||
event_ts = validation_payload.get('event_ts', '')
|
||||
plain_token = validation_payload.get('plain_token', '')
|
||||
msg = event_ts + plain_token
|
||||
|
||||
# sign
|
||||
signature = private_key.sign(msg.encode()).hex()
|
||||
|
||||
response = {
|
||||
"plain_token": plain_token,
|
||||
"signature": signature,
|
||||
'plain_token': plain_token,
|
||||
'signature': signature,
|
||||
}
|
||||
return response
|
||||
|
||||
@@ -36,7 +36,12 @@ class WecomBotEvent(dict):
|
||||
"""
|
||||
用户名称
|
||||
"""
|
||||
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid
|
||||
return (
|
||||
self.get('username', '')
|
||||
or self.get('from', {}).get('alias', '')
|
||||
or self.get('from', {}).get('name', '')
|
||||
or self.userid
|
||||
)
|
||||
|
||||
@property
|
||||
def chatname(self) -> str:
|
||||
@@ -121,7 +126,7 @@ class WecomBotEvent(dict):
|
||||
消息id
|
||||
"""
|
||||
return self.get('msgid', '')
|
||||
|
||||
|
||||
@property
|
||||
def ai_bot_id(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -30,7 +30,6 @@ class WebhookRouterGroup(group.RouterGroup):
|
||||
适配器返回的响应
|
||||
"""
|
||||
try:
|
||||
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
|
||||
if not runtime_bot:
|
||||
@@ -39,11 +38,9 @@ class WebhookRouterGroup(group.RouterGroup):
|
||||
if not runtime_bot.enable:
|
||||
return quart.jsonify({'error': 'Bot is disabled'}), 403
|
||||
|
||||
|
||||
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
|
||||
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
|
||||
|
||||
|
||||
response = await runtime_bot.adapter.handle_unified_webhook(
|
||||
bot_uuid=bot_uuid,
|
||||
path=path,
|
||||
|
||||
@@ -59,7 +59,16 @@ class BotService:
|
||||
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||
|
||||
# Webhook URL for unified webhook adapters (independent of bot running state)
|
||||
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE', 'lark']:
|
||||
if persistence_bot['adapter'] in [
|
||||
'wecom',
|
||||
'wecombot',
|
||||
'officialaccount',
|
||||
'qqofficial',
|
||||
'slack',
|
||||
'wecomcs',
|
||||
'LINE',
|
||||
'lark',
|
||||
]:
|
||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||
webhook_url = f'/bots/{bot_uuid}'
|
||||
adapter_runtime_values['webhook_url'] = webhook_url
|
||||
|
||||
@@ -34,7 +34,6 @@ from .. import taskmgr
|
||||
from ...telemetry import telemetry as telemetry_module
|
||||
|
||||
|
||||
|
||||
@stage.stage_class('BuildAppStage')
|
||||
class BuildAppStage(stage.BootingStage):
|
||||
"""Build LangBot application"""
|
||||
|
||||
@@ -244,7 +244,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
|
||||
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
|
||||
|
||||
|
||||
if message.message_type == 'text':
|
||||
element_list = []
|
||||
|
||||
@@ -310,7 +309,11 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
]
|
||||
elif message.message_type == 'audio':
|
||||
message_content['content'] = [
|
||||
{'tag': 'audio', 'file_key': message_content['file_key'], "duration": message_content.get('duration',0)}
|
||||
{
|
||||
'tag': 'audio',
|
||||
'file_key': message_content['file_key'],
|
||||
'duration': message_content.get('duration', 0),
|
||||
}
|
||||
]
|
||||
|
||||
for ele in message_content['content']:
|
||||
@@ -367,12 +370,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
audio_bytes = response.file.read()
|
||||
audio_base64 = base64.b64encode(audio_bytes).decode()
|
||||
|
||||
|
||||
# Get content type from response headers
|
||||
content_type = response.raw.headers.get('content-type', 'audio/mpeg')
|
||||
|
||||
|
||||
|
||||
mime_main = content_type.split(';')[0].strip()
|
||||
ext = mimetypes.guess_extension(mime_main) or '.bin'
|
||||
temp_dir = tempfile.gettempdir()
|
||||
@@ -418,7 +418,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
file_bytes = response.file.read()
|
||||
file_base64 = base64.b64encode(file_bytes).decode()
|
||||
|
||||
|
||||
file_format = response.raw.headers['content-type']
|
||||
|
||||
file_size = len(file_bytes)
|
||||
@@ -453,7 +452,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
return platform_message.MessageChain(lb_msg_list)
|
||||
|
||||
|
||||
|
||||
@@ -18,52 +18,52 @@ import langbot_plugin.api.entities.builtin.platform.entities as platform_entitie
|
||||
def split_string_by_bytes(text, limit=2048, encoding='utf-8'):
    """Split a string into parts whose encoded size is at most ``limit`` bytes.

    Characters are never cut in half and never dropped, so joining the
    returned parts always reproduces ``text``. A single character whose
    encoding is wider than ``limit`` cannot be split without corrupting it;
    it is emitted as its own (oversized) part instead.

    Args:
        text (str): The original string to split.
        limit (int): The maximum byte size for each split part. Values below
            1 are treated as 1 so the loop always makes progress.
        encoding (str): The encoding to use (default is 'utf-8').

    Returns:
        list: A list of split strings; empty when ``text`` is empty.
    """
    data = text.encode(encoding)
    total = len(data)
    # A zero or negative window would yield empty or negative slices.
    window = max(limit, 1)

    parts = []
    start = 0
    while start < total:
        end = min(start + window, total)
        # errors='ignore' drops a character that was cut at the end of the
        # window, leaving only the whole characters that fit.
        part = data[start:end].decode(encoding, errors='ignore')

        # The next character alone is wider than the window: widen byte by
        # byte until it decodes, rather than skipping bytes (which loses the
        # character) or looping forever on an empty part.
        while not part and end < total:
            end += 1
            part = data[start:end].decode(encoding, errors='ignore')

        if not part:
            # Remaining bytes never decode; not expected for data encoded
            # above, but stop rather than spin.
            break

        parts.append(part)
        # Advance by the bytes actually consumed so the next window starts
        # on a character boundary.
        start += len(part.encode(encoding))

    return parts
|
||||
|
||||
|
||||
@@ -75,13 +75,15 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
chunks = split_string_by_bytes(msg.text)
|
||||
content_list.extend([
|
||||
{
|
||||
'type': 'text',
|
||||
'content': chunk,
|
||||
}
|
||||
for chunk in chunks
|
||||
])
|
||||
content_list.extend(
|
||||
[
|
||||
{
|
||||
'type': 'text',
|
||||
'content': chunk,
|
||||
}
|
||||
for chunk in chunks
|
||||
]
|
||||
)
|
||||
elif type(msg) is platform_message.Image:
|
||||
content_list.append(
|
||||
{
|
||||
|
||||
@@ -51,9 +51,10 @@ class SeekDBEmbedding(requester.ProviderAPIRequester):
|
||||
await self.initialize()
|
||||
|
||||
if self._embedding_function is None:
|
||||
raise RuntimeError("SeekDB embedding function initialization failed")
|
||||
raise RuntimeError('SeekDB embedding function initialization failed')
|
||||
|
||||
return self._embedding_function(input_text)
|
||||
except Exception as e:
|
||||
from .. import errors
|
||||
|
||||
raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}')
|
||||
|
||||
@@ -218,10 +218,14 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
parameters = {}
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query)
|
||||
|
||||
|
||||
# Handle return value content
|
||||
tool_content = None
|
||||
if isinstance(func_ret, list) and len(func_ret) > 0 and isinstance(func_ret[0], provider_message.ContentElement):
|
||||
if (
|
||||
isinstance(func_ret, list)
|
||||
and len(func_ret) > 0
|
||||
and isinstance(func_ret[0], provider_message.ContentElement)
|
||||
):
|
||||
tool_content = func_ret
|
||||
else:
|
||||
tool_content = json.dumps(func_ret, ensure_ascii=False)
|
||||
|
||||
@@ -68,15 +68,16 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
return plain_text
|
||||
|
||||
async def _process_stream_response(self, response: aiohttp.ClientResponse) -> typing.AsyncGenerator[
|
||||
provider_message.Message, None]:
|
||||
async def _process_stream_response(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况"""
|
||||
full_content = ""
|
||||
full_content = ''
|
||||
chunk_idx = 0
|
||||
is_final = False
|
||||
message_idx = 0
|
||||
|
||||
buffer = ""
|
||||
buffer = ''
|
||||
decoder = json.JSONDecoder()
|
||||
|
||||
async for raw_chunk in response.content.iter_chunked(1024):
|
||||
@@ -129,7 +130,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
preview = chunk_str[:200]
|
||||
except Exception:
|
||||
preview = '<unavailable>'
|
||||
self.ap.logger.warning(f"Failed to process chunk: {e}; chunk preview: {preview}")
|
||||
self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}')
|
||||
|
||||
# 流结束后,尝试解析残余 buffer
|
||||
if buffer:
|
||||
@@ -151,7 +152,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
)
|
||||
except Exception as e:
|
||||
preview = buffer[:200]
|
||||
self.ap.logger.warning(f"Failed to parse remaining buffer: {e}; buffer preview: {preview}")
|
||||
self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}')
|
||||
|
||||
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用n8n webhook"""
|
||||
@@ -165,7 +166,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
# 准备请求数据
|
||||
payload = {
|
||||
# 基本消息内容
|
||||
'chatInput' :plain_text, # 考虑到之前用户直接用的message model这里添加新键
|
||||
'chatInput': plain_text, # 考虑到之前用户直接用的message model这里添加新键
|
||||
'message': plain_text,
|
||||
'user_message_text': plain_text,
|
||||
'conversation_id': query.session.using_conversation.uuid,
|
||||
@@ -217,57 +218,49 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
# 调用webhook
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with session.post(
|
||||
self.webhook_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
auth=auth,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
if is_stream:
|
||||
# 流式请求
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
|
||||
# 处理流式响应
|
||||
async for chunk in self._process_stream_response(response):
|
||||
yield chunk
|
||||
else:
|
||||
async with session.post(
|
||||
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||
) as response:
|
||||
try:
|
||||
async for chunk in self._process_stream_response(response):
|
||||
output_content = chunk.content if chunk.is_final else ''
|
||||
except:
|
||||
# 非流式请求(保持原有逻辑)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
|
||||
# 处理流式响应
|
||||
async for chunk in self._process_stream_response(response):
|
||||
yield chunk
|
||||
else:
|
||||
async with session.post(
|
||||
self.webhook_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
auth=auth,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
try:
|
||||
async for chunk in self._process_stream_response(response):
|
||||
output_content = chunk.content if chunk.is_final else ''
|
||||
except:
|
||||
# 非流式请求(保持原有逻辑)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
raise Exception(f'n8n webhook call failed: {response.status}, {error_text}')
|
||||
# 解析响应
|
||||
response_data = await response.json()
|
||||
self.ap.logger.debug(f'n8n webhook response: {response_data}')
|
||||
|
||||
# 解析响应
|
||||
response_data = await response.json()
|
||||
self.ap.logger.debug(f'n8n webhook response: {response_data}')
|
||||
# 从响应中提取输出
|
||||
if self.output_key in response_data:
|
||||
output_content = response_data[self.output_key]
|
||||
else:
|
||||
# 如果没有指定的输出键,则使用整个响应
|
||||
output_content = json.dumps(response_data, ensure_ascii=False)
|
||||
|
||||
# 从响应中提取输出
|
||||
if self.output_key in response_data:
|
||||
output_content = response_data[self.output_key]
|
||||
else:
|
||||
# 如果没有指定的输出键,则使用整个响应
|
||||
output_content = json.dumps(response_data, ensure_ascii=False)
|
||||
|
||||
# 返回消息
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=output_content,
|
||||
)
|
||||
# 返回消息
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=output_content,
|
||||
)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
|
||||
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
|
||||
@@ -275,4 +268,4 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行请求"""
|
||||
async for msg in self._call_webhook(query):
|
||||
yield msg
|
||||
yield msg
|
||||
|
||||
@@ -194,7 +194,7 @@ class RuntimeMCPSession:
|
||||
|
||||
async def func(*, _tool=tool, **kwargs):
|
||||
if not self.session:
|
||||
raise Exception("MCP session is not connected")
|
||||
raise Exception('MCP session is not connected')
|
||||
|
||||
result = await self.session.call_tool(_tool.name, kwargs)
|
||||
if result.isError:
|
||||
@@ -202,8 +202,8 @@ class RuntimeMCPSession:
|
||||
for content in result.content:
|
||||
if content.type == 'text':
|
||||
error_texts.append(content.text)
|
||||
raise Exception("\n".join(error_texts) if error_texts else "Unknown error from MCP tool")
|
||||
|
||||
raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool')
|
||||
|
||||
result_contents: list[provider_message.ContentElement] = []
|
||||
for content in result.content:
|
||||
if content.type == 'text':
|
||||
@@ -213,7 +213,7 @@ class RuntimeMCPSession:
|
||||
elif content.type == 'resource':
|
||||
# TODO: Handle resource content
|
||||
pass
|
||||
|
||||
|
||||
return result_contents
|
||||
|
||||
func.__name__ = tool.name
|
||||
@@ -221,8 +221,8 @@ class RuntimeMCPSession:
|
||||
self.functions.append(
|
||||
resource_tool.LLMTool(
|
||||
name=tool.name,
|
||||
human_desc=tool.description or "",
|
||||
description=tool.description or "",
|
||||
human_desc=tool.description or '',
|
||||
description=tool.description or '',
|
||||
parameters=tool.inputSchema,
|
||||
func=func,
|
||||
)
|
||||
@@ -338,13 +338,10 @@ class MCPLoader(loader.ToolLoader):
|
||||
"""
|
||||
uuid_ = server_config.get('uuid')
|
||||
if not uuid_:
|
||||
self.ap.logger.warning(
|
||||
'Server UUID is None for MCP server, maybe testing in the config page.'
|
||||
)
|
||||
self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.')
|
||||
uuid_ = str(uuid_module.uuid4())
|
||||
server_config['uuid'] = uuid_
|
||||
|
||||
|
||||
name = server_config['name']
|
||||
uuid = server_config['uuid']
|
||||
mode = server_config['mode']
|
||||
|
||||
@@ -35,9 +35,9 @@ class Embedder(BaseService):
|
||||
# get embeddings (batch size limit: 64 for OpenAI)
|
||||
MAX_BATCH_SIZE = 64
|
||||
embeddings_list: list[list[float]] = []
|
||||
|
||||
|
||||
for i in range(0, len(chunks), MAX_BATCH_SIZE):
|
||||
batch = chunks[i:i + MAX_BATCH_SIZE]
|
||||
batch = chunks[i : i + MAX_BATCH_SIZE]
|
||||
batch_embeddings = await embedding_model.provider.requester.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=batch,
|
||||
|
||||
@@ -55,12 +55,7 @@ class VectorDBManager:
|
||||
user = pgvector_config.get('user', 'postgres')
|
||||
password = pgvector_config.get('password', 'postgres')
|
||||
self.vector_db = PgVectorDatabase(
|
||||
self.ap,
|
||||
host=host,
|
||||
port=port,
|
||||
database=database,
|
||||
user=user,
|
||||
password=password
|
||||
self.ap, host=host, port=port, database=database, user=user, password=password
|
||||
)
|
||||
self.ap.logger.info('Initialized pgvector database backend.')
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from langbot.pkg.core import app
|
||||
class MilvusVectorDatabase(VectorDatabase):
|
||||
"""Milvus vector database implementation"""
|
||||
|
||||
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None):
|
||||
def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None):
|
||||
"""Initialize Milvus vector database
|
||||
|
||||
Args:
|
||||
@@ -34,32 +34,32 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
|
||||
else:
|
||||
self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
|
||||
self.ap.logger.info(f"Connected to Milvus at {self.uri}")
|
||||
self.ap.logger.info(f'Connected to Milvus at {self.uri}')
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Failed to connect to Milvus: {e}")
|
||||
self.ap.logger.error(f'Failed to connect to Milvus: {e}')
|
||||
raise
|
||||
|
||||
@staticmethod
def _normalize_collection_name(collection: str) -> str:
    """Normalize collection name to comply with Milvus naming requirements.

    Milvus requirements:
    - First character must be an underscore or letter
    - Can only contain numbers, letters and underscores

    Every character outside ``[0-9A-Za-z_]`` is replaced with an underscore,
    not only the hyphens found in UUIDs. Names that were already valid after
    hyphen replacement are returned exactly as before.

    Args:
        collection: Original collection name (e.g., UUID with hyphens)

    Returns:
        Normalized collection name; an empty input is returned unchanged.
    """
    # Local import keeps this method self-contained.
    import re

    # NOTE(review): assumes Milvus accepts ASCII letters only; confirm
    # against the deployed Milvus version's naming rules.
    normalized = re.sub(r'[^0-9A-Za-z_]', '_', collection)

    # After the substitution the only invalid leading character is a digit.
    if normalized[:1].isdigit():
        normalized = 'kb_' + normalized

    return normalized
|
||||
|
||||
async def _ensure_vector_index(self, collection: str) -> None:
|
||||
@@ -70,15 +70,11 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
"""
|
||||
index_params = IndexParams()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="AUTOINDEX",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self.client.create_index,
|
||||
collection_name=collection,
|
||||
index_params=index_params
|
||||
field_name='vector',
|
||||
index_type='AUTOINDEX',
|
||||
metric_type='COSINE',
|
||||
)
|
||||
await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params)
|
||||
|
||||
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
|
||||
"""Internal method to get or create a Milvus collection with proper configuration.
|
||||
@@ -89,14 +85,12 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
"""
|
||||
# Normalize collection name for Milvus compatibility
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
|
||||
if collection in self._collections:
|
||||
return collection
|
||||
|
||||
# Check if collection exists
|
||||
has_collection = await asyncio.to_thread(
|
||||
self.client.has_collection, collection_name=collection
|
||||
)
|
||||
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
|
||||
|
||||
if not has_collection:
|
||||
# Default dimension if not specified (for backward compatibility)
|
||||
@@ -104,24 +98,26 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
vector_size = 1536
|
||||
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
||||
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255),
|
||||
FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255),
|
||||
]
|
||||
|
||||
schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors")
|
||||
schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors')
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
collection_name=collection,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
metric_type='COSINE',
|
||||
)
|
||||
|
||||
await self._ensure_vector_index(collection)
|
||||
self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX")
|
||||
self.ap.logger.info(
|
||||
f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX"
|
||||
)
|
||||
else:
|
||||
# Ensure index exists for existing collection
|
||||
await self._ensure_index_if_missing(collection)
|
||||
@@ -137,11 +133,8 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
collection: Normalized collection name
|
||||
"""
|
||||
try:
|
||||
indexes = await asyncio.to_thread(
|
||||
self.client.list_indexes,
|
||||
collection_name=collection
|
||||
)
|
||||
if "vector" not in indexes:
|
||||
indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection)
|
||||
if 'vector' not in indexes:
|
||||
await self._ensure_vector_index(collection)
|
||||
self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
|
||||
except Exception as e:
|
||||
@@ -172,7 +165,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
metadatas: List of metadata dictionaries for each vector
|
||||
"""
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
|
||||
if not embeddings_list:
|
||||
return
|
||||
|
||||
@@ -184,39 +177,30 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
data = []
|
||||
for i, vector_id in enumerate(ids):
|
||||
entry = {
|
||||
"id": vector_id,
|
||||
"vector": embeddings_list[i],
|
||||
'id': vector_id,
|
||||
'vector': embeddings_list[i],
|
||||
}
|
||||
# Add metadata fields
|
||||
if metadatas and i < len(metadatas):
|
||||
metadata = metadatas[i]
|
||||
# Add common metadata fields
|
||||
if "text" in metadata:
|
||||
entry["text"] = metadata["text"]
|
||||
if "file_id" in metadata:
|
||||
entry["file_id"] = metadata["file_id"]
|
||||
if "uuid" in metadata:
|
||||
entry["chunk_uuid"] = metadata["uuid"]
|
||||
if 'text' in metadata:
|
||||
entry['text'] = metadata['text']
|
||||
if 'file_id' in metadata:
|
||||
entry['file_id'] = metadata['file_id']
|
||||
if 'uuid' in metadata:
|
||||
entry['chunk_uuid'] = metadata['uuid']
|
||||
data.append(entry)
|
||||
|
||||
# Insert data into Milvus
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
collection_name=collection,
|
||||
data=data
|
||||
)
|
||||
await asyncio.to_thread(self.client.insert, collection_name=collection, data=data)
|
||||
|
||||
# Load collection for searching (Milvus requires this)
|
||||
await asyncio.to_thread(
|
||||
self.client.load_collection,
|
||||
collection_name=collection
|
||||
)
|
||||
await asyncio.to_thread(self.client.load_collection, collection_name=collection)
|
||||
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
||||
|
||||
async def search(
|
||||
self, collection: str, query_embedding: list[float], k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
||||
"""Search for similar vectors in Milvus collection
|
||||
|
||||
Args:
|
||||
@@ -231,10 +215,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Perform search
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {}
|
||||
}
|
||||
search_params = {'metric_type': 'COSINE', 'params': {}}
|
||||
|
||||
results = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
@@ -242,7 +223,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
data=[query_embedding],
|
||||
limit=k,
|
||||
search_params=search_params,
|
||||
output_fields=["text", "file_id", "chunk_uuid"]
|
||||
output_fields=['text', 'file_id', 'chunk_uuid'],
|
||||
)
|
||||
|
||||
# Convert results to Chroma-compatible format
|
||||
@@ -253,30 +234,24 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
|
||||
if results and len(results) > 0:
|
||||
for hit in results[0]:
|
||||
ids.append(hit.get("id", ""))
|
||||
distances.append(hit.get("distance", 0.0))
|
||||
ids.append(hit.get('id', ''))
|
||||
distances.append(hit.get('distance', 0.0))
|
||||
|
||||
# Build metadata from entity fields
|
||||
entity = hit.get("entity", {})
|
||||
entity = hit.get('entity', {})
|
||||
metadata = {}
|
||||
if "text" in entity:
|
||||
metadata["text"] = entity["text"]
|
||||
if "file_id" in entity:
|
||||
metadata["file_id"] = entity["file_id"]
|
||||
if "chunk_uuid" in entity:
|
||||
metadata["uuid"] = entity["chunk_uuid"]
|
||||
if 'text' in entity:
|
||||
metadata['text'] = entity['text']
|
||||
if 'file_id' in entity:
|
||||
metadata['file_id'] = entity['file_id']
|
||||
if 'chunk_uuid' in entity:
|
||||
metadata['uuid'] = entity['chunk_uuid']
|
||||
metadatas.append(metadata)
|
||||
|
||||
# Return in Chroma-compatible format (nested lists)
|
||||
result = {
|
||||
"ids": [ids],
|
||||
"distances": [distances],
|
||||
"metadatas": [metadatas]
|
||||
}
|
||||
result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
|
||||
|
||||
self.ap.logger.info(
|
||||
f"Milvus search in '{collection}' returned {len(ids)} results"
|
||||
)
|
||||
self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results")
|
||||
return result
|
||||
|
||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||
@@ -290,14 +265,8 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
await self.get_or_create_collection(collection)
|
||||
|
||||
# Delete entities matching the file_id
|
||||
await asyncio.to_thread(
|
||||
self.client.delete,
|
||||
collection_name=collection,
|
||||
filter=f'file_id == "{file_id}"'
|
||||
)
|
||||
self.ap.logger.info(
|
||||
f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}"
|
||||
)
|
||||
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
|
||||
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
|
||||
|
||||
async def delete_collection(self, collection: str):
|
||||
"""Delete a Milvus collection
|
||||
@@ -306,18 +275,14 @@ class MilvusVectorDatabase(VectorDatabase):
|
||||
collection: Collection name to delete
|
||||
"""
|
||||
collection = self._normalize_collection_name(collection)
|
||||
|
||||
|
||||
self._collections.discard(collection)
|
||||
|
||||
# Check if collection exists before attempting deletion
|
||||
has_collection = await asyncio.to_thread(
|
||||
self.client.has_collection, collection_name=collection
|
||||
)
|
||||
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
|
||||
|
||||
if has_collection:
|
||||
await asyncio.to_thread(
|
||||
self.client.drop_collection, collection_name=collection
|
||||
)
|
||||
await asyncio.to_thread(self.client.drop_collection, collection_name=collection)
|
||||
self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
|
||||
else:
|
||||
self.ap.logger.warning(f"Milvus collection '{collection}' not found")
|
||||
|
||||
@@ -9,27 +9,28 @@ from typing import Any
|
||||
|
||||
def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
||||
"""Apply environment variable overrides to data/config.yaml
|
||||
|
||||
Environment variables should be uppercase and use __ (double underscore)
|
||||
|
||||
Environment variables should be uppercase and use __ (double underscore)
|
||||
to represent nested keys. For example:
|
||||
- CONCURRENCY__PIPELINE overrides concurrency.pipeline
|
||||
- PLUGIN__RUNTIME_WS_URL overrides plugin.runtime_ws_url
|
||||
|
||||
|
||||
Arrays and dict types are ignored.
|
||||
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Updated configuration dictionary
|
||||
"""
|
||||
|
||||
def convert_value(value: str, original_value: Any) -> Any:
|
||||
"""Convert string value to appropriate type based on original value
|
||||
|
||||
|
||||
Args:
|
||||
value: String value from environment variable
|
||||
original_value: Original value to infer type from
|
||||
|
||||
|
||||
Returns:
|
||||
Converted value (falls back to string if conversion fails)
|
||||
"""
|
||||
@@ -49,7 +50,7 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
||||
return value
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
# Process environment variables
|
||||
for env_key, env_value in os.environ.items():
|
||||
# Check if the environment variable is uppercase and contains __
|
||||
@@ -57,18 +58,18 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
||||
continue
|
||||
if '__' not in env_key:
|
||||
continue
|
||||
|
||||
|
||||
# Convert environment variable name to config path
|
||||
# e.g., CONCURRENCY__PIPELINE -> ['concurrency', 'pipeline']
|
||||
keys = [key.lower() for key in env_key.split('__')]
|
||||
|
||||
|
||||
# Navigate to the target value and validate the path
|
||||
current = cfg
|
||||
|
||||
|
||||
for i, key in enumerate(keys):
|
||||
if not isinstance(current, dict) or key not in current:
|
||||
break
|
||||
|
||||
|
||||
if i == len(keys) - 1:
|
||||
# At the final key - check if it's a scalar value
|
||||
if isinstance(current[key], (dict, list)):
|
||||
@@ -81,248 +82,182 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
||||
else:
|
||||
# Navigate deeper
|
||||
current = current[key]
|
||||
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
class TestEnvOverrides:
|
||||
"""Test environment variable override functionality"""
|
||||
|
||||
|
||||
def test_simple_string_override(self):
|
||||
"""Test overriding a simple string value"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'api': {'port': 5300}}
|
||||
|
||||
# Set environment variable
|
||||
os.environ['API__PORT'] = '8080'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
assert result['api']['port'] == 8080
|
||||
|
||||
|
||||
# Cleanup
|
||||
del os.environ['API__PORT']
|
||||
|
||||
|
||||
def test_nested_key_override(self):
|
||||
"""Test overriding nested keys with __ delimiter"""
|
||||
cfg = {
|
||||
'concurrency': {
|
||||
'pipeline': 20,
|
||||
'session': 1
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'concurrency': {'pipeline': 20, 'session': 1}}
|
||||
|
||||
os.environ['CONCURRENCY__PIPELINE'] = '50'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
assert result['concurrency']['pipeline'] == 50
|
||||
assert result['concurrency']['session'] == 1 # Unchanged
|
||||
|
||||
|
||||
del os.environ['CONCURRENCY__PIPELINE']
|
||||
|
||||
|
||||
def test_deep_nested_override(self):
|
||||
"""Test overriding deeply nested keys"""
|
||||
cfg = {
|
||||
'system': {
|
||||
'jwt': {
|
||||
'expire': 604800,
|
||||
'secret': ''
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'system': {'jwt': {'expire': 604800, 'secret': ''}}}
|
||||
|
||||
os.environ['SYSTEM__JWT__EXPIRE'] = '86400'
|
||||
os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
assert result['system']['jwt']['expire'] == 86400
|
||||
assert result['system']['jwt']['secret'] == 'my_secret_key'
|
||||
|
||||
|
||||
del os.environ['SYSTEM__JWT__EXPIRE']
|
||||
del os.environ['SYSTEM__JWT__SECRET']
|
||||
|
||||
|
||||
def test_underscore_in_key(self):
|
||||
"""Test keys with underscores like runtime_ws_url"""
|
||||
cfg = {
|
||||
'plugin': {
|
||||
'enable': True,
|
||||
'runtime_ws_url': 'ws://localhost:5400/control/ws'
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'plugin': {'enable': True, 'runtime_ws_url': 'ws://localhost:5400/control/ws'}}
|
||||
|
||||
os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
assert result['plugin']['runtime_ws_url'] == 'ws://newhost:6000/ws'
|
||||
|
||||
|
||||
del os.environ['PLUGIN__RUNTIME_WS_URL']
|
||||
|
||||
|
||||
def test_boolean_conversion(self):
|
||||
"""Test boolean value conversion"""
|
||||
cfg = {
|
||||
'plugin': {
|
||||
'enable': True,
|
||||
'enable_marketplace': False
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'plugin': {'enable': True, 'enable_marketplace': False}}
|
||||
|
||||
os.environ['PLUGIN__ENABLE'] = 'false'
|
||||
os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
assert result['plugin']['enable'] is False
|
||||
assert result['plugin']['enable_marketplace'] is True
|
||||
|
||||
|
||||
del os.environ['PLUGIN__ENABLE']
|
||||
del os.environ['PLUGIN__ENABLE_MARKETPLACE']
|
||||
|
||||
|
||||
def test_ignore_dict_type(self):
|
||||
"""Test that dict types are ignored"""
|
||||
cfg = {
|
||||
'database': {
|
||||
'use': 'sqlite',
|
||||
'sqlite': {
|
||||
'path': 'data/langbot.db'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
|
||||
|
||||
# Try to override a dict value - should be ignored
|
||||
os.environ['DATABASE__SQLITE'] = 'new_value'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
# Should remain a dict, not overridden
|
||||
assert isinstance(result['database']['sqlite'], dict)
|
||||
assert result['database']['sqlite']['path'] == 'data/langbot.db'
|
||||
|
||||
|
||||
del os.environ['DATABASE__SQLITE']
|
||||
|
||||
|
||||
def test_ignore_list_type(self):
|
||||
"""Test that list/array types are ignored"""
|
||||
cfg = {
|
||||
'admins': ['admin1', 'admin2'],
|
||||
'command': {
|
||||
'enable': True,
|
||||
'prefix': ['!', '!']
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'admins': ['admin1', 'admin2'], 'command': {'enable': True, 'prefix': ['!', '!']}}
|
||||
|
||||
# Try to override list values - should be ignored
|
||||
os.environ['ADMINS'] = 'admin3'
|
||||
os.environ['COMMAND__PREFIX'] = '?'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
# Should remain lists, not overridden
|
||||
assert isinstance(result['admins'], list)
|
||||
assert result['admins'] == ['admin1', 'admin2']
|
||||
assert isinstance(result['command']['prefix'], list)
|
||||
assert result['command']['prefix'] == ['!', '!']
|
||||
|
||||
|
||||
del os.environ['ADMINS']
|
||||
del os.environ['COMMAND__PREFIX']
|
||||
|
||||
|
||||
def test_lowercase_env_var_ignored(self):
|
||||
"""Test that lowercase environment variables are ignored"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'api': {'port': 5300}}
|
||||
|
||||
os.environ['api__port'] = '8080'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
# Should not be overridden
|
||||
assert result['api']['port'] == 5300
|
||||
|
||||
|
||||
del os.environ['api__port']
|
||||
|
||||
|
||||
def test_no_double_underscore_ignored(self):
|
||||
"""Test that env vars without __ are ignored"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'api': {'port': 5300}}
|
||||
|
||||
os.environ['APIPORT'] = '8080'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
# Should not be overridden
|
||||
assert result['api']['port'] == 5300
|
||||
|
||||
|
||||
del os.environ['APIPORT']
|
||||
|
||||
|
||||
def test_nonexistent_key_ignored(self):
|
||||
"""Test that env vars for non-existent keys are ignored"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'api': {'port': 5300}}
|
||||
|
||||
os.environ['API__NONEXISTENT'] = 'value'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
# Should not create new key
|
||||
assert 'nonexistent' not in result['api']
|
||||
|
||||
|
||||
del os.environ['API__NONEXISTENT']
|
||||
|
||||
|
||||
def test_integer_conversion(self):
|
||||
"""Test integer value conversion"""
|
||||
cfg = {
|
||||
'concurrency': {
|
||||
'pipeline': 20
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'concurrency': {'pipeline': 20}}
|
||||
|
||||
os.environ['CONCURRENCY__PIPELINE'] = '100'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
assert result['concurrency']['pipeline'] == 100
|
||||
assert isinstance(result['concurrency']['pipeline'], int)
|
||||
|
||||
|
||||
del os.environ['CONCURRENCY__PIPELINE']
|
||||
|
||||
|
||||
def test_multiple_overrides(self):
|
||||
"""Test multiple environment variable overrides at once"""
|
||||
cfg = {
|
||||
'api': {
|
||||
'port': 5300
|
||||
},
|
||||
'concurrency': {
|
||||
'pipeline': 20,
|
||||
'session': 1
|
||||
},
|
||||
'plugin': {
|
||||
'enable': False
|
||||
}
|
||||
}
|
||||
|
||||
cfg = {'api': {'port': 5300}, 'concurrency': {'pipeline': 20, 'session': 1}, 'plugin': {'enable': False}}
|
||||
|
||||
os.environ['API__PORT'] = '8080'
|
||||
os.environ['CONCURRENCY__PIPELINE'] = '50'
|
||||
os.environ['PLUGIN__ENABLE'] = 'true'
|
||||
|
||||
|
||||
result = _apply_env_overrides_to_config(cfg)
|
||||
|
||||
|
||||
assert result['api']['port'] == 8080
|
||||
assert result['concurrency']['pipeline'] == 50
|
||||
assert result['plugin']['enable'] is True
|
||||
|
||||
|
||||
del os.environ['API__PORT']
|
||||
del os.environ['CONCURRENCY__PIPELINE']
|
||||
del os.environ['PLUGIN__ENABLE']
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Test plugin list filtering by component kinds."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
@@ -31,16 +30,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Tool',
|
||||
'metadata': {'name': 'tool1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -53,15 +43,8 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'KnowledgeRetriever',
|
||||
'metadata': {'name': 'retriever1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||
],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -73,16 +56,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Command',
|
||||
'metadata': {'name': 'cmd1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Command', 'metadata': {'name': 'cmd1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -94,16 +68,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'EventListener',
|
||||
'metadata': {'name': 'listener1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'EventListener', 'metadata': {'name': 'listener1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -116,23 +81,9 @@ async def test_plugin_list_filter_by_component_kinds():
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'KnowledgeRetriever',
|
||||
'metadata': {'name': 'retriever2'}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Tool',
|
||||
'metadata': {'name': 'tool2'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever2'}}}},
|
||||
{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -187,16 +138,7 @@ async def test_plugin_list_filter_no_filter():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Tool',
|
||||
'metadata': {'name': 'tool1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -209,15 +151,8 @@ async def test_plugin_list_filter_no_filter():
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'KnowledgeRetriever',
|
||||
'metadata': {'name': 'retriever1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -267,15 +202,8 @@ async def test_plugin_list_filter_empty_result():
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'KnowledgeRetriever',
|
||||
'metadata': {'name': 'retriever1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -321,16 +249,7 @@ async def test_plugin_list_filter_plugin_without_components():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': [
|
||||
{
|
||||
'manifest': {
|
||||
'manifest': {
|
||||
'kind': 'Tool',
|
||||
'metadata': {'name': 'tool1'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||
},
|
||||
{
|
||||
'debug': False,
|
||||
@@ -342,7 +261,7 @@ async def test_plugin_list_filter_plugin_without_components():
|
||||
}
|
||||
}
|
||||
},
|
||||
'components': []
|
||||
'components': [],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user