mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: add GitHub Actions workflow for linting with Ruff (#1929)
* feat: add GitHub Actions workflow for linting with Ruff * refactor: rename lint job and add formatting step to Ruff workflow * chore: run ruff format * chore: rename Ruff lint job to 'Lint' and add frontend linting workflow
This commit is contained in:
committed by
GitHub
parent
e60cb6ad0e
commit
fc6e414be4
60
.github/workflows/lint.yml
vendored
Normal file
60
.github/workflows/lint.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
name: Lint
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- master
|
||||||
|
- dev
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
ruff:
|
||||||
|
name: Ruff Lint & Format
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.12'
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync --dev
|
||||||
|
|
||||||
|
- name: Run ruff check
|
||||||
|
run: uv run ruff check src
|
||||||
|
|
||||||
|
- name: Run ruff format
|
||||||
|
run: uv run ruff format src --check
|
||||||
|
|
||||||
|
frontend:
|
||||||
|
name: Frontend Lint
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: '25'
|
||||||
|
|
||||||
|
- name: Install pnpm
|
||||||
|
uses: pnpm/action-setup@v4
|
||||||
|
with:
|
||||||
|
version: 9
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
working-directory: web
|
||||||
|
run: pnpm install
|
||||||
|
|
||||||
|
- name: Run lint
|
||||||
|
working-directory: web
|
||||||
|
run: pnpm lint
|
||||||
@@ -85,7 +85,6 @@ class QQOfficialClient:
|
|||||||
req: Quart Request 对象
|
req: Quart Request 对象
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
body = await req.get_data()
|
body = await req.get_data()
|
||||||
|
|
||||||
print(f'[QQ Official] Received request, body length: {len(body)}')
|
print(f'[QQ Official] Received request, body length: {len(body)}')
|
||||||
@@ -96,7 +95,6 @@ class QQOfficialClient:
|
|||||||
|
|
||||||
payload = json.loads(body)
|
payload = json.loads(body)
|
||||||
|
|
||||||
|
|
||||||
if payload.get('op') == 13:
|
if payload.get('op') == 13:
|
||||||
validation_data = payload.get('d')
|
validation_data = payload.get('d')
|
||||||
if not validation_data:
|
if not validation_data:
|
||||||
@@ -276,21 +274,21 @@ class QQOfficialClient:
|
|||||||
seed = bot_secret
|
seed = bot_secret
|
||||||
while len(seed) < target_size:
|
while len(seed) < target_size:
|
||||||
seed *= 2
|
seed *= 2
|
||||||
return seed[:target_size].encode("utf-8")
|
return seed[:target_size].encode('utf-8')
|
||||||
|
|
||||||
async def verify(self, validation_payload: dict):
|
async def verify(self, validation_payload: dict):
|
||||||
seed = await self.repeat_seed(self.secret)
|
seed = await self.repeat_seed(self.secret)
|
||||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||||
|
|
||||||
event_ts = validation_payload.get("event_ts", "")
|
event_ts = validation_payload.get('event_ts', '')
|
||||||
plain_token = validation_payload.get("plain_token", "")
|
plain_token = validation_payload.get('plain_token', '')
|
||||||
msg = event_ts + plain_token
|
msg = event_ts + plain_token
|
||||||
|
|
||||||
# sign
|
# sign
|
||||||
signature = private_key.sign(msg.encode()).hex()
|
signature = private_key.sign(msg.encode()).hex()
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"plain_token": plain_token,
|
'plain_token': plain_token,
|
||||||
"signature": signature,
|
'signature': signature,
|
||||||
}
|
}
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -36,7 +36,12 @@ class WecomBotEvent(dict):
|
|||||||
"""
|
"""
|
||||||
用户名称
|
用户名称
|
||||||
"""
|
"""
|
||||||
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid
|
return (
|
||||||
|
self.get('username', '')
|
||||||
|
or self.get('from', {}).get('alias', '')
|
||||||
|
or self.get('from', {}).get('name', '')
|
||||||
|
or self.userid
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chatname(self) -> str:
|
def chatname(self) -> str:
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ class WebhookRouterGroup(group.RouterGroup):
|
|||||||
适配器返回的响应
|
适配器返回的响应
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||||
|
|
||||||
if not runtime_bot:
|
if not runtime_bot:
|
||||||
@@ -39,11 +38,9 @@ class WebhookRouterGroup(group.RouterGroup):
|
|||||||
if not runtime_bot.enable:
|
if not runtime_bot.enable:
|
||||||
return quart.jsonify({'error': 'Bot is disabled'}), 403
|
return quart.jsonify({'error': 'Bot is disabled'}), 403
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
|
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
|
||||||
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
|
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
|
||||||
|
|
||||||
|
|
||||||
response = await runtime_bot.adapter.handle_unified_webhook(
|
response = await runtime_bot.adapter.handle_unified_webhook(
|
||||||
bot_uuid=bot_uuid,
|
bot_uuid=bot_uuid,
|
||||||
path=path,
|
path=path,
|
||||||
|
|||||||
@@ -59,7 +59,16 @@ class BotService:
|
|||||||
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||||
|
|
||||||
# Webhook URL for unified webhook adapters (independent of bot running state)
|
# Webhook URL for unified webhook adapters (independent of bot running state)
|
||||||
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE', 'lark']:
|
if persistence_bot['adapter'] in [
|
||||||
|
'wecom',
|
||||||
|
'wecombot',
|
||||||
|
'officialaccount',
|
||||||
|
'qqofficial',
|
||||||
|
'slack',
|
||||||
|
'wecomcs',
|
||||||
|
'LINE',
|
||||||
|
'lark',
|
||||||
|
]:
|
||||||
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
|
||||||
webhook_url = f'/bots/{bot_uuid}'
|
webhook_url = f'/bots/{bot_uuid}'
|
||||||
adapter_runtime_values['webhook_url'] = webhook_url
|
adapter_runtime_values['webhook_url'] = webhook_url
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ from .. import taskmgr
|
|||||||
from ...telemetry import telemetry as telemetry_module
|
from ...telemetry import telemetry as telemetry_module
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('BuildAppStage')
|
@stage.stage_class('BuildAppStage')
|
||||||
class BuildAppStage(stage.BootingStage):
|
class BuildAppStage(stage.BootingStage):
|
||||||
"""Build LangBot application"""
|
"""Build LangBot application"""
|
||||||
|
|||||||
@@ -244,7 +244,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
|
|
||||||
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
|
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
|
||||||
|
|
||||||
|
|
||||||
if message.message_type == 'text':
|
if message.message_type == 'text':
|
||||||
element_list = []
|
element_list = []
|
||||||
|
|
||||||
@@ -310,7 +309,11 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
]
|
]
|
||||||
elif message.message_type == 'audio':
|
elif message.message_type == 'audio':
|
||||||
message_content['content'] = [
|
message_content['content'] = [
|
||||||
{'tag': 'audio', 'file_key': message_content['file_key'], "duration": message_content.get('duration',0)}
|
{
|
||||||
|
'tag': 'audio',
|
||||||
|
'file_key': message_content['file_key'],
|
||||||
|
'duration': message_content.get('duration', 0),
|
||||||
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
for ele in message_content['content']:
|
for ele in message_content['content']:
|
||||||
@@ -367,12 +370,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
audio_bytes = response.file.read()
|
audio_bytes = response.file.read()
|
||||||
audio_base64 = base64.b64encode(audio_bytes).decode()
|
audio_base64 = base64.b64encode(audio_bytes).decode()
|
||||||
|
|
||||||
|
|
||||||
# Get content type from response headers
|
# Get content type from response headers
|
||||||
content_type = response.raw.headers.get('content-type', 'audio/mpeg')
|
content_type = response.raw.headers.get('content-type', 'audio/mpeg')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
mime_main = content_type.split(';')[0].strip()
|
mime_main = content_type.split(';')[0].strip()
|
||||||
ext = mimetypes.guess_extension(mime_main) or '.bin'
|
ext = mimetypes.guess_extension(mime_main) or '.bin'
|
||||||
temp_dir = tempfile.gettempdir()
|
temp_dir = tempfile.gettempdir()
|
||||||
@@ -418,7 +418,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
file_bytes = response.file.read()
|
file_bytes = response.file.read()
|
||||||
file_base64 = base64.b64encode(file_bytes).decode()
|
file_base64 = base64.b64encode(file_bytes).decode()
|
||||||
|
|
||||||
|
|
||||||
file_format = response.raw.headers['content-type']
|
file_format = response.raw.headers['content-type']
|
||||||
|
|
||||||
file_size = len(file_bytes)
|
file_size = len(file_bytes)
|
||||||
@@ -453,7 +452,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
return platform_message.MessageChain(lb_msg_list)
|
return platform_message.MessageChain(lb_msg_list)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -75,13 +75,15 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
|||||||
for msg in message_chain:
|
for msg in message_chain:
|
||||||
if type(msg) is platform_message.Plain:
|
if type(msg) is platform_message.Plain:
|
||||||
chunks = split_string_by_bytes(msg.text)
|
chunks = split_string_by_bytes(msg.text)
|
||||||
content_list.extend([
|
content_list.extend(
|
||||||
|
[
|
||||||
{
|
{
|
||||||
'type': 'text',
|
'type': 'text',
|
||||||
'content': chunk,
|
'content': chunk,
|
||||||
}
|
}
|
||||||
for chunk in chunks
|
for chunk in chunks
|
||||||
])
|
]
|
||||||
|
)
|
||||||
elif type(msg) is platform_message.Image:
|
elif type(msg) is platform_message.Image:
|
||||||
content_list.append(
|
content_list.append(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -51,9 +51,10 @@ class SeekDBEmbedding(requester.ProviderAPIRequester):
|
|||||||
await self.initialize()
|
await self.initialize()
|
||||||
|
|
||||||
if self._embedding_function is None:
|
if self._embedding_function is None:
|
||||||
raise RuntimeError("SeekDB embedding function initialization failed")
|
raise RuntimeError('SeekDB embedding function initialization failed')
|
||||||
|
|
||||||
return self._embedding_function(input_text)
|
return self._embedding_function(input_text)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from .. import errors
|
from .. import errors
|
||||||
|
|
||||||
raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}')
|
raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}')
|
||||||
|
|||||||
@@ -221,7 +221,11 @@ class LocalAgentRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
# Handle return value content
|
# Handle return value content
|
||||||
tool_content = None
|
tool_content = None
|
||||||
if isinstance(func_ret, list) and len(func_ret) > 0 and isinstance(func_ret[0], provider_message.ContentElement):
|
if (
|
||||||
|
isinstance(func_ret, list)
|
||||||
|
and len(func_ret) > 0
|
||||||
|
and isinstance(func_ret[0], provider_message.ContentElement)
|
||||||
|
):
|
||||||
tool_content = func_ret
|
tool_content = func_ret
|
||||||
else:
|
else:
|
||||||
tool_content = json.dumps(func_ret, ensure_ascii=False)
|
tool_content = json.dumps(func_ret, ensure_ascii=False)
|
||||||
|
|||||||
@@ -68,15 +68,16 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
return plain_text
|
return plain_text
|
||||||
|
|
||||||
async def _process_stream_response(self, response: aiohttp.ClientResponse) -> typing.AsyncGenerator[
|
async def _process_stream_response(
|
||||||
provider_message.Message, None]:
|
self, response: aiohttp.ClientResponse
|
||||||
|
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||||
"""处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况"""
|
"""处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况"""
|
||||||
full_content = ""
|
full_content = ''
|
||||||
chunk_idx = 0
|
chunk_idx = 0
|
||||||
is_final = False
|
is_final = False
|
||||||
message_idx = 0
|
message_idx = 0
|
||||||
|
|
||||||
buffer = ""
|
buffer = ''
|
||||||
decoder = json.JSONDecoder()
|
decoder = json.JSONDecoder()
|
||||||
|
|
||||||
async for raw_chunk in response.content.iter_chunked(1024):
|
async for raw_chunk in response.content.iter_chunked(1024):
|
||||||
@@ -129,7 +130,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
preview = chunk_str[:200]
|
preview = chunk_str[:200]
|
||||||
except Exception:
|
except Exception:
|
||||||
preview = '<unavailable>'
|
preview = '<unavailable>'
|
||||||
self.ap.logger.warning(f"Failed to process chunk: {e}; chunk preview: {preview}")
|
self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}')
|
||||||
|
|
||||||
# 流结束后,尝试解析残余 buffer
|
# 流结束后,尝试解析残余 buffer
|
||||||
if buffer:
|
if buffer:
|
||||||
@@ -151,7 +152,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
preview = buffer[:200]
|
preview = buffer[:200]
|
||||||
self.ap.logger.warning(f"Failed to parse remaining buffer: {e}; buffer preview: {preview}")
|
self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}')
|
||||||
|
|
||||||
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||||
"""调用n8n webhook"""
|
"""调用n8n webhook"""
|
||||||
@@ -165,7 +166,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
# 准备请求数据
|
# 准备请求数据
|
||||||
payload = {
|
payload = {
|
||||||
# 基本消息内容
|
# 基本消息内容
|
||||||
'chatInput' :plain_text, # 考虑到之前用户直接用的message model这里添加新键
|
'chatInput': plain_text, # 考虑到之前用户直接用的message model这里添加新键
|
||||||
'message': plain_text,
|
'message': plain_text,
|
||||||
'user_message_text': plain_text,
|
'user_message_text': plain_text,
|
||||||
'conversation_id': query.session.using_conversation.uuid,
|
'conversation_id': query.session.using_conversation.uuid,
|
||||||
@@ -220,11 +221,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
if is_stream:
|
if is_stream:
|
||||||
# 流式请求
|
# 流式请求
|
||||||
async with session.post(
|
async with session.post(
|
||||||
self.webhook_url,
|
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||||
json=payload,
|
|
||||||
headers=headers,
|
|
||||||
auth=auth,
|
|
||||||
timeout=self.timeout
|
|
||||||
) as response:
|
) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
@@ -236,11 +233,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
yield chunk
|
yield chunk
|
||||||
else:
|
else:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
self.webhook_url,
|
self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
|
||||||
json=payload,
|
|
||||||
headers=headers,
|
|
||||||
auth=auth,
|
|
||||||
timeout=self.timeout
|
|
||||||
) as response:
|
) as response:
|
||||||
try:
|
try:
|
||||||
async for chunk in self._process_stream_response(response):
|
async for chunk in self._process_stream_response(response):
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ class RuntimeMCPSession:
|
|||||||
|
|
||||||
async def func(*, _tool=tool, **kwargs):
|
async def func(*, _tool=tool, **kwargs):
|
||||||
if not self.session:
|
if not self.session:
|
||||||
raise Exception("MCP session is not connected")
|
raise Exception('MCP session is not connected')
|
||||||
|
|
||||||
result = await self.session.call_tool(_tool.name, kwargs)
|
result = await self.session.call_tool(_tool.name, kwargs)
|
||||||
if result.isError:
|
if result.isError:
|
||||||
@@ -202,7 +202,7 @@ class RuntimeMCPSession:
|
|||||||
for content in result.content:
|
for content in result.content:
|
||||||
if content.type == 'text':
|
if content.type == 'text':
|
||||||
error_texts.append(content.text)
|
error_texts.append(content.text)
|
||||||
raise Exception("\n".join(error_texts) if error_texts else "Unknown error from MCP tool")
|
raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool')
|
||||||
|
|
||||||
result_contents: list[provider_message.ContentElement] = []
|
result_contents: list[provider_message.ContentElement] = []
|
||||||
for content in result.content:
|
for content in result.content:
|
||||||
@@ -221,8 +221,8 @@ class RuntimeMCPSession:
|
|||||||
self.functions.append(
|
self.functions.append(
|
||||||
resource_tool.LLMTool(
|
resource_tool.LLMTool(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
human_desc=tool.description or "",
|
human_desc=tool.description or '',
|
||||||
description=tool.description or "",
|
description=tool.description or '',
|
||||||
parameters=tool.inputSchema,
|
parameters=tool.inputSchema,
|
||||||
func=func,
|
func=func,
|
||||||
)
|
)
|
||||||
@@ -338,13 +338,10 @@ class MCPLoader(loader.ToolLoader):
|
|||||||
"""
|
"""
|
||||||
uuid_ = server_config.get('uuid')
|
uuid_ = server_config.get('uuid')
|
||||||
if not uuid_:
|
if not uuid_:
|
||||||
self.ap.logger.warning(
|
self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.')
|
||||||
'Server UUID is None for MCP server, maybe testing in the config page.'
|
|
||||||
)
|
|
||||||
uuid_ = str(uuid_module.uuid4())
|
uuid_ = str(uuid_module.uuid4())
|
||||||
server_config['uuid'] = uuid_
|
server_config['uuid'] = uuid_
|
||||||
|
|
||||||
|
|
||||||
name = server_config['name']
|
name = server_config['name']
|
||||||
uuid = server_config['uuid']
|
uuid = server_config['uuid']
|
||||||
mode = server_config['mode']
|
mode = server_config['mode']
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class Embedder(BaseService):
|
|||||||
embeddings_list: list[list[float]] = []
|
embeddings_list: list[list[float]] = []
|
||||||
|
|
||||||
for i in range(0, len(chunks), MAX_BATCH_SIZE):
|
for i in range(0, len(chunks), MAX_BATCH_SIZE):
|
||||||
batch = chunks[i:i + MAX_BATCH_SIZE]
|
batch = chunks[i : i + MAX_BATCH_SIZE]
|
||||||
batch_embeddings = await embedding_model.provider.requester.invoke_embedding(
|
batch_embeddings = await embedding_model.provider.requester.invoke_embedding(
|
||||||
model=embedding_model,
|
model=embedding_model,
|
||||||
input_text=batch,
|
input_text=batch,
|
||||||
|
|||||||
@@ -55,12 +55,7 @@ class VectorDBManager:
|
|||||||
user = pgvector_config.get('user', 'postgres')
|
user = pgvector_config.get('user', 'postgres')
|
||||||
password = pgvector_config.get('password', 'postgres')
|
password = pgvector_config.get('password', 'postgres')
|
||||||
self.vector_db = PgVectorDatabase(
|
self.vector_db = PgVectorDatabase(
|
||||||
self.ap,
|
self.ap, host=host, port=port, database=database, user=user, password=password
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
database=database,
|
|
||||||
user=user,
|
|
||||||
password=password
|
|
||||||
)
|
)
|
||||||
self.ap.logger.info('Initialized pgvector database backend.')
|
self.ap.logger.info('Initialized pgvector database backend.')
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from langbot.pkg.core import app
|
|||||||
class MilvusVectorDatabase(VectorDatabase):
|
class MilvusVectorDatabase(VectorDatabase):
|
||||||
"""Milvus vector database implementation"""
|
"""Milvus vector database implementation"""
|
||||||
|
|
||||||
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None):
|
def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None):
|
||||||
"""Initialize Milvus vector database
|
"""Initialize Milvus vector database
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -34,9 +34,9 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
|
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
|
||||||
else:
|
else:
|
||||||
self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
|
self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
|
||||||
self.ap.logger.info(f"Connected to Milvus at {self.uri}")
|
self.ap.logger.info(f'Connected to Milvus at {self.uri}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.error(f"Failed to connect to Milvus: {e}")
|
self.ap.logger.error(f'Failed to connect to Milvus: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -70,15 +70,11 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
"""
|
"""
|
||||||
index_params = IndexParams()
|
index_params = IndexParams()
|
||||||
index_params.add_index(
|
index_params.add_index(
|
||||||
field_name="vector",
|
field_name='vector',
|
||||||
index_type="AUTOINDEX",
|
index_type='AUTOINDEX',
|
||||||
metric_type="COSINE",
|
metric_type='COSINE',
|
||||||
)
|
|
||||||
await asyncio.to_thread(
|
|
||||||
self.client.create_index,
|
|
||||||
collection_name=collection,
|
|
||||||
index_params=index_params
|
|
||||||
)
|
)
|
||||||
|
await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params)
|
||||||
|
|
||||||
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
|
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
|
||||||
"""Internal method to get or create a Milvus collection with proper configuration.
|
"""Internal method to get or create a Milvus collection with proper configuration.
|
||||||
@@ -94,9 +90,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
return collection
|
return collection
|
||||||
|
|
||||||
# Check if collection exists
|
# Check if collection exists
|
||||||
has_collection = await asyncio.to_thread(
|
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
|
||||||
self.client.has_collection, collection_name=collection
|
|
||||||
)
|
|
||||||
|
|
||||||
if not has_collection:
|
if not has_collection:
|
||||||
# Default dimension if not specified (for backward compatibility)
|
# Default dimension if not specified (for backward compatibility)
|
||||||
@@ -104,24 +98,26 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
vector_size = 1536
|
vector_size = 1536
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255),
|
||||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size),
|
||||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
|
||||||
FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255),
|
FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255),
|
||||||
FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255),
|
FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255),
|
||||||
]
|
]
|
||||||
|
|
||||||
schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors")
|
schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors')
|
||||||
|
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self.client.create_collection,
|
self.client.create_collection,
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
metric_type="COSINE",
|
metric_type='COSINE',
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._ensure_vector_index(collection)
|
await self._ensure_vector_index(collection)
|
||||||
self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX")
|
self.ap.logger.info(
|
||||||
|
f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Ensure index exists for existing collection
|
# Ensure index exists for existing collection
|
||||||
await self._ensure_index_if_missing(collection)
|
await self._ensure_index_if_missing(collection)
|
||||||
@@ -137,11 +133,8 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
collection: Normalized collection name
|
collection: Normalized collection name
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
indexes = await asyncio.to_thread(
|
indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection)
|
||||||
self.client.list_indexes,
|
if 'vector' not in indexes:
|
||||||
collection_name=collection
|
|
||||||
)
|
|
||||||
if "vector" not in indexes:
|
|
||||||
await self._ensure_vector_index(collection)
|
await self._ensure_vector_index(collection)
|
||||||
self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
|
self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -184,39 +177,30 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
data = []
|
data = []
|
||||||
for i, vector_id in enumerate(ids):
|
for i, vector_id in enumerate(ids):
|
||||||
entry = {
|
entry = {
|
||||||
"id": vector_id,
|
'id': vector_id,
|
||||||
"vector": embeddings_list[i],
|
'vector': embeddings_list[i],
|
||||||
}
|
}
|
||||||
# Add metadata fields
|
# Add metadata fields
|
||||||
if metadatas and i < len(metadatas):
|
if metadatas and i < len(metadatas):
|
||||||
metadata = metadatas[i]
|
metadata = metadatas[i]
|
||||||
# Add common metadata fields
|
# Add common metadata fields
|
||||||
if "text" in metadata:
|
if 'text' in metadata:
|
||||||
entry["text"] = metadata["text"]
|
entry['text'] = metadata['text']
|
||||||
if "file_id" in metadata:
|
if 'file_id' in metadata:
|
||||||
entry["file_id"] = metadata["file_id"]
|
entry['file_id'] = metadata['file_id']
|
||||||
if "uuid" in metadata:
|
if 'uuid' in metadata:
|
||||||
entry["chunk_uuid"] = metadata["uuid"]
|
entry['chunk_uuid'] = metadata['uuid']
|
||||||
data.append(entry)
|
data.append(entry)
|
||||||
|
|
||||||
# Insert data into Milvus
|
# Insert data into Milvus
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(self.client.insert, collection_name=collection, data=data)
|
||||||
self.client.insert,
|
|
||||||
collection_name=collection,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load collection for searching (Milvus requires this)
|
# Load collection for searching (Milvus requires this)
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(self.client.load_collection, collection_name=collection)
|
||||||
self.client.load_collection,
|
|
||||||
collection_name=collection
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
|
||||||
|
|
||||||
async def search(
|
async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
|
||||||
self, collection: str, query_embedding: list[float], k: int = 5
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Search for similar vectors in Milvus collection
|
"""Search for similar vectors in Milvus collection
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -231,10 +215,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
await self.get_or_create_collection(collection)
|
await self.get_or_create_collection(collection)
|
||||||
|
|
||||||
# Perform search
|
# Perform search
|
||||||
search_params = {
|
search_params = {'metric_type': 'COSINE', 'params': {}}
|
||||||
"metric_type": "COSINE",
|
|
||||||
"params": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
results = await asyncio.to_thread(
|
results = await asyncio.to_thread(
|
||||||
self.client.search,
|
self.client.search,
|
||||||
@@ -242,7 +223,7 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
data=[query_embedding],
|
data=[query_embedding],
|
||||||
limit=k,
|
limit=k,
|
||||||
search_params=search_params,
|
search_params=search_params,
|
||||||
output_fields=["text", "file_id", "chunk_uuid"]
|
output_fields=['text', 'file_id', 'chunk_uuid'],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert results to Chroma-compatible format
|
# Convert results to Chroma-compatible format
|
||||||
@@ -253,30 +234,24 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
|
|
||||||
if results and len(results) > 0:
|
if results and len(results) > 0:
|
||||||
for hit in results[0]:
|
for hit in results[0]:
|
||||||
ids.append(hit.get("id", ""))
|
ids.append(hit.get('id', ''))
|
||||||
distances.append(hit.get("distance", 0.0))
|
distances.append(hit.get('distance', 0.0))
|
||||||
|
|
||||||
# Build metadata from entity fields
|
# Build metadata from entity fields
|
||||||
entity = hit.get("entity", {})
|
entity = hit.get('entity', {})
|
||||||
metadata = {}
|
metadata = {}
|
||||||
if "text" in entity:
|
if 'text' in entity:
|
||||||
metadata["text"] = entity["text"]
|
metadata['text'] = entity['text']
|
||||||
if "file_id" in entity:
|
if 'file_id' in entity:
|
||||||
metadata["file_id"] = entity["file_id"]
|
metadata['file_id'] = entity['file_id']
|
||||||
if "chunk_uuid" in entity:
|
if 'chunk_uuid' in entity:
|
||||||
metadata["uuid"] = entity["chunk_uuid"]
|
metadata['uuid'] = entity['chunk_uuid']
|
||||||
metadatas.append(metadata)
|
metadatas.append(metadata)
|
||||||
|
|
||||||
# Return in Chroma-compatible format (nested lists)
|
# Return in Chroma-compatible format (nested lists)
|
||||||
result = {
|
result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
|
||||||
"ids": [ids],
|
|
||||||
"distances": [distances],
|
|
||||||
"metadatas": [metadatas]
|
|
||||||
}
|
|
||||||
|
|
||||||
self.ap.logger.info(
|
self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results")
|
||||||
f"Milvus search in '{collection}' returned {len(ids)} results"
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
async def delete_by_file_id(self, collection: str, file_id: str) -> None:
|
||||||
@@ -290,14 +265,8 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
await self.get_or_create_collection(collection)
|
await self.get_or_create_collection(collection)
|
||||||
|
|
||||||
# Delete entities matching the file_id
|
# Delete entities matching the file_id
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
|
||||||
self.client.delete,
|
self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
|
||||||
collection_name=collection,
|
|
||||||
filter=f'file_id == "{file_id}"'
|
|
||||||
)
|
|
||||||
self.ap.logger.info(
|
|
||||||
f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def delete_collection(self, collection: str):
|
async def delete_collection(self, collection: str):
|
||||||
"""Delete a Milvus collection
|
"""Delete a Milvus collection
|
||||||
@@ -310,14 +279,10 @@ class MilvusVectorDatabase(VectorDatabase):
|
|||||||
self._collections.discard(collection)
|
self._collections.discard(collection)
|
||||||
|
|
||||||
# Check if collection exists before attempting deletion
|
# Check if collection exists before attempting deletion
|
||||||
has_collection = await asyncio.to_thread(
|
has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
|
||||||
self.client.has_collection, collection_name=collection
|
|
||||||
)
|
|
||||||
|
|
||||||
if has_collection:
|
if has_collection:
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(self.client.drop_collection, collection_name=collection)
|
||||||
self.client.drop_collection, collection_name=collection
|
|
||||||
)
|
|
||||||
self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
|
self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
|
||||||
else:
|
else:
|
||||||
self.ap.logger.warning(f"Milvus collection '{collection}' not found")
|
self.ap.logger.warning(f"Milvus collection '{collection}' not found")
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
|
|||||||
Returns:
|
Returns:
|
||||||
Updated configuration dictionary
|
Updated configuration dictionary
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def convert_value(value: str, original_value: Any) -> Any:
|
def convert_value(value: str, original_value: Any) -> Any:
|
||||||
"""Convert string value to appropriate type based on original value
|
"""Convert string value to appropriate type based on original value
|
||||||
|
|
||||||
@@ -90,11 +91,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_simple_string_override(self):
|
def test_simple_string_override(self):
|
||||||
"""Test overriding a simple string value"""
|
"""Test overriding a simple string value"""
|
||||||
cfg = {
|
cfg = {'api': {'port': 5300}}
|
||||||
'api': {
|
|
||||||
'port': 5300
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Set environment variable
|
# Set environment variable
|
||||||
os.environ['API__PORT'] = '8080'
|
os.environ['API__PORT'] = '8080'
|
||||||
@@ -108,12 +105,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_nested_key_override(self):
|
def test_nested_key_override(self):
|
||||||
"""Test overriding nested keys with __ delimiter"""
|
"""Test overriding nested keys with __ delimiter"""
|
||||||
cfg = {
|
cfg = {'concurrency': {'pipeline': 20, 'session': 1}}
|
||||||
'concurrency': {
|
|
||||||
'pipeline': 20,
|
|
||||||
'session': 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['CONCURRENCY__PIPELINE'] = '50'
|
os.environ['CONCURRENCY__PIPELINE'] = '50'
|
||||||
|
|
||||||
@@ -126,14 +118,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_deep_nested_override(self):
|
def test_deep_nested_override(self):
|
||||||
"""Test overriding deeply nested keys"""
|
"""Test overriding deeply nested keys"""
|
||||||
cfg = {
|
cfg = {'system': {'jwt': {'expire': 604800, 'secret': ''}}}
|
||||||
'system': {
|
|
||||||
'jwt': {
|
|
||||||
'expire': 604800,
|
|
||||||
'secret': ''
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['SYSTEM__JWT__EXPIRE'] = '86400'
|
os.environ['SYSTEM__JWT__EXPIRE'] = '86400'
|
||||||
os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key'
|
os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key'
|
||||||
@@ -148,12 +133,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_underscore_in_key(self):
|
def test_underscore_in_key(self):
|
||||||
"""Test keys with underscores like runtime_ws_url"""
|
"""Test keys with underscores like runtime_ws_url"""
|
||||||
cfg = {
|
cfg = {'plugin': {'enable': True, 'runtime_ws_url': 'ws://localhost:5400/control/ws'}}
|
||||||
'plugin': {
|
|
||||||
'enable': True,
|
|
||||||
'runtime_ws_url': 'ws://localhost:5400/control/ws'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws'
|
os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws'
|
||||||
|
|
||||||
@@ -165,12 +145,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_boolean_conversion(self):
|
def test_boolean_conversion(self):
|
||||||
"""Test boolean value conversion"""
|
"""Test boolean value conversion"""
|
||||||
cfg = {
|
cfg = {'plugin': {'enable': True, 'enable_marketplace': False}}
|
||||||
'plugin': {
|
|
||||||
'enable': True,
|
|
||||||
'enable_marketplace': False
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['PLUGIN__ENABLE'] = 'false'
|
os.environ['PLUGIN__ENABLE'] = 'false'
|
||||||
os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true'
|
os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true'
|
||||||
@@ -185,14 +160,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_ignore_dict_type(self):
|
def test_ignore_dict_type(self):
|
||||||
"""Test that dict types are ignored"""
|
"""Test that dict types are ignored"""
|
||||||
cfg = {
|
cfg = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
|
||||||
'database': {
|
|
||||||
'use': 'sqlite',
|
|
||||||
'sqlite': {
|
|
||||||
'path': 'data/langbot.db'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Try to override a dict value - should be ignored
|
# Try to override a dict value - should be ignored
|
||||||
os.environ['DATABASE__SQLITE'] = 'new_value'
|
os.environ['DATABASE__SQLITE'] = 'new_value'
|
||||||
@@ -207,13 +175,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_ignore_list_type(self):
|
def test_ignore_list_type(self):
|
||||||
"""Test that list/array types are ignored"""
|
"""Test that list/array types are ignored"""
|
||||||
cfg = {
|
cfg = {'admins': ['admin1', 'admin2'], 'command': {'enable': True, 'prefix': ['!', '!']}}
|
||||||
'admins': ['admin1', 'admin2'],
|
|
||||||
'command': {
|
|
||||||
'enable': True,
|
|
||||||
'prefix': ['!', '!']
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Try to override list values - should be ignored
|
# Try to override list values - should be ignored
|
||||||
os.environ['ADMINS'] = 'admin3'
|
os.environ['ADMINS'] = 'admin3'
|
||||||
@@ -232,11 +194,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_lowercase_env_var_ignored(self):
|
def test_lowercase_env_var_ignored(self):
|
||||||
"""Test that lowercase environment variables are ignored"""
|
"""Test that lowercase environment variables are ignored"""
|
||||||
cfg = {
|
cfg = {'api': {'port': 5300}}
|
||||||
'api': {
|
|
||||||
'port': 5300
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['api__port'] = '8080'
|
os.environ['api__port'] = '8080'
|
||||||
|
|
||||||
@@ -249,11 +207,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_no_double_underscore_ignored(self):
|
def test_no_double_underscore_ignored(self):
|
||||||
"""Test that env vars without __ are ignored"""
|
"""Test that env vars without __ are ignored"""
|
||||||
cfg = {
|
cfg = {'api': {'port': 5300}}
|
||||||
'api': {
|
|
||||||
'port': 5300
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['APIPORT'] = '8080'
|
os.environ['APIPORT'] = '8080'
|
||||||
|
|
||||||
@@ -266,11 +220,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_nonexistent_key_ignored(self):
|
def test_nonexistent_key_ignored(self):
|
||||||
"""Test that env vars for non-existent keys are ignored"""
|
"""Test that env vars for non-existent keys are ignored"""
|
||||||
cfg = {
|
cfg = {'api': {'port': 5300}}
|
||||||
'api': {
|
|
||||||
'port': 5300
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['API__NONEXISTENT'] = 'value'
|
os.environ['API__NONEXISTENT'] = 'value'
|
||||||
|
|
||||||
@@ -283,11 +233,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_integer_conversion(self):
|
def test_integer_conversion(self):
|
||||||
"""Test integer value conversion"""
|
"""Test integer value conversion"""
|
||||||
cfg = {
|
cfg = {'concurrency': {'pipeline': 20}}
|
||||||
'concurrency': {
|
|
||||||
'pipeline': 20
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['CONCURRENCY__PIPELINE'] = '100'
|
os.environ['CONCURRENCY__PIPELINE'] = '100'
|
||||||
|
|
||||||
@@ -300,18 +246,7 @@ class TestEnvOverrides:
|
|||||||
|
|
||||||
def test_multiple_overrides(self):
|
def test_multiple_overrides(self):
|
||||||
"""Test multiple environment variable overrides at once"""
|
"""Test multiple environment variable overrides at once"""
|
||||||
cfg = {
|
cfg = {'api': {'port': 5300}, 'concurrency': {'pipeline': 20, 'session': 1}, 'plugin': {'enable': False}}
|
||||||
'api': {
|
|
||||||
'port': 5300
|
|
||||||
},
|
|
||||||
'concurrency': {
|
|
||||||
'pipeline': 20,
|
|
||||||
'session': 1
|
|
||||||
},
|
|
||||||
'plugin': {
|
|
||||||
'enable': False
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
os.environ['API__PORT'] = '8080'
|
os.environ['API__PORT'] = '8080'
|
||||||
os.environ['CONCURRENCY__PIPELINE'] = '50'
|
os.environ['CONCURRENCY__PIPELINE'] = '50'
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Test plugin list filtering by component kinds."""
|
"""Test plugin list filtering by component kinds."""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -31,16 +30,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||||
{
|
|
||||||
'manifest': {
|
|
||||||
'manifest': {
|
|
||||||
'kind': 'Tool',
|
|
||||||
'metadata': {'name': 'tool1'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'debug': False,
|
'debug': False,
|
||||||
@@ -53,15 +43,8 @@ async def test_plugin_list_filter_by_component_kinds():
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [
|
||||||
{
|
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||||
'manifest': {
|
],
|
||||||
'manifest': {
|
|
||||||
'kind': 'KnowledgeRetriever',
|
|
||||||
'metadata': {'name': 'retriever1'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'debug': False,
|
'debug': False,
|
||||||
@@ -73,16 +56,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [{'manifest': {'manifest': {'kind': 'Command', 'metadata': {'name': 'cmd1'}}}}],
|
||||||
{
|
|
||||||
'manifest': {
|
|
||||||
'manifest': {
|
|
||||||
'kind': 'Command',
|
|
||||||
'metadata': {'name': 'cmd1'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'debug': False,
|
'debug': False,
|
||||||
@@ -94,16 +68,7 @@ async def test_plugin_list_filter_by_component_kinds():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [{'manifest': {'manifest': {'kind': 'EventListener', 'metadata': {'name': 'listener1'}}}}],
|
||||||
{
|
|
||||||
'manifest': {
|
|
||||||
'manifest': {
|
|
||||||
'kind': 'EventListener',
|
|
||||||
'metadata': {'name': 'listener1'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'debug': False,
|
'debug': False,
|
||||||
@@ -116,23 +81,9 @@ async def test_plugin_list_filter_by_component_kinds():
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [
|
||||||
{
|
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever2'}}}},
|
||||||
'manifest': {
|
{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}},
|
||||||
'manifest': {
|
],
|
||||||
'kind': 'KnowledgeRetriever',
|
|
||||||
'metadata': {'name': 'retriever2'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'manifest': {
|
|
||||||
'manifest': {
|
|
||||||
'kind': 'Tool',
|
|
||||||
'metadata': {'name': 'tool2'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -187,16 +138,7 @@ async def test_plugin_list_filter_no_filter():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||||
{
|
|
||||||
'manifest': {
|
|
||||||
'manifest': {
|
|
||||||
'kind': 'Tool',
|
|
||||||
'metadata': {'name': 'tool1'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'debug': False,
|
'debug': False,
|
||||||
@@ -209,15 +151,8 @@ async def test_plugin_list_filter_no_filter():
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [
|
||||||
{
|
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||||
'manifest': {
|
],
|
||||||
'manifest': {
|
|
||||||
'kind': 'KnowledgeRetriever',
|
|
||||||
'metadata': {'name': 'retriever1'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -267,15 +202,8 @@ async def test_plugin_list_filter_empty_result():
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [
|
||||||
{
|
{'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
|
||||||
'manifest': {
|
],
|
||||||
'manifest': {
|
|
||||||
'kind': 'KnowledgeRetriever',
|
|
||||||
'metadata': {'name': 'retriever1'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -321,16 +249,7 @@ async def test_plugin_list_filter_plugin_without_components():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': [
|
'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
|
||||||
{
|
|
||||||
'manifest': {
|
|
||||||
'manifest': {
|
|
||||||
'kind': 'Tool',
|
|
||||||
'metadata': {'name': 'tool1'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'debug': False,
|
'debug': False,
|
||||||
@@ -342,7 +261,7 @@ async def test_plugin_list_filter_plugin_without_components():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'components': []
|
'components': [],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user