feat: add GitHub Actions workflow for linting with Ruff (#1929)

* feat: add GitHub Actions workflow for linting with Ruff

* refactor: rename lint job and add formatting step to Ruff workflow

* chore: run ruff format

* chore: rename Ruff lint job to 'Lint' and add frontend linting workflow
This commit is contained in:
Junyan Qin (Chin)
2026-01-23 13:43:12 +08:00
committed by GitHub
parent e60cb6ad0e
commit fc6e414be4
17 changed files with 327 additions and 450 deletions

60
.github/workflows/lint.yml vendored Normal file
View File

@@ -0,0 +1,60 @@
# Lint workflow: runs Ruff (lint + format check) on the Python sources in
# `src` and the frontend linter in `web`.
name: Lint

on:
  push:
    branches:
      - main
      - master
      - dev
  pull_request:
    types: [opened, synchronize, reopened, ready_for_review]

# Both jobs only check out and read the repository, so limit the default
# GITHUB_TOKEN to read-only access.
permissions:
  contents: read

jobs:
  ruff:
    name: Ruff Lint & Format
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Install dependencies
        run: uv sync --dev

      - name: Run ruff check
        run: uv run ruff check src

      # `--check` reports files that would be reformatted and fails the step
      # instead of rewriting them.
      - name: Run ruff format
        run: uv run ruff format src --check

  frontend:
    name: Frontend Lint
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '25'

      - name: Install pnpm
        uses: pnpm/action-setup@v4
        with:
          version: 9

      # The frontend project lives in `web`, so install and lint from there.
      - name: Install dependencies
        working-directory: web
        run: pnpm install

      - name: Run lint
        working-directory: web
        run: pnpm lint

View File

@@ -85,7 +85,6 @@ class QQOfficialClient:
req: Quart Request 对象 req: Quart Request 对象
""" """
try: try:
body = await req.get_data() body = await req.get_data()
print(f'[QQ Official] Received request, body length: {len(body)}') print(f'[QQ Official] Received request, body length: {len(body)}')
@@ -96,7 +95,6 @@ class QQOfficialClient:
payload = json.loads(body) payload = json.loads(body)
if payload.get('op') == 13: if payload.get('op') == 13:
validation_data = payload.get('d') validation_data = payload.get('d')
if not validation_data: if not validation_data:
@@ -276,21 +274,21 @@ class QQOfficialClient:
seed = bot_secret seed = bot_secret
while len(seed) < target_size: while len(seed) < target_size:
seed *= 2 seed *= 2
return seed[:target_size].encode("utf-8") return seed[:target_size].encode('utf-8')
async def verify(self, validation_payload: dict): async def verify(self, validation_payload: dict):
seed = await self.repeat_seed(self.secret) seed = await self.repeat_seed(self.secret)
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed) private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
event_ts = validation_payload.get("event_ts", "") event_ts = validation_payload.get('event_ts', '')
plain_token = validation_payload.get("plain_token", "") plain_token = validation_payload.get('plain_token', '')
msg = event_ts + plain_token msg = event_ts + plain_token
# sign # sign
signature = private_key.sign(msg.encode()).hex() signature = private_key.sign(msg.encode()).hex()
response = { response = {
"plain_token": plain_token, 'plain_token': plain_token,
"signature": signature, 'signature': signature,
} }
return response return response

View File

@@ -36,7 +36,12 @@ class WecomBotEvent(dict):
""" """
用户名称 用户名称
""" """
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid return (
self.get('username', '')
or self.get('from', {}).get('alias', '')
or self.get('from', {}).get('name', '')
or self.userid
)
@property @property
def chatname(self) -> str: def chatname(self) -> str:

View File

@@ -30,7 +30,6 @@ class WebhookRouterGroup(group.RouterGroup):
适配器返回的响应 适配器返回的响应
""" """
try: try:
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid) runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
if not runtime_bot: if not runtime_bot:
@@ -39,11 +38,9 @@ class WebhookRouterGroup(group.RouterGroup):
if not runtime_bot.enable: if not runtime_bot.enable:
return quart.jsonify({'error': 'Bot is disabled'}), 403 return quart.jsonify({'error': 'Bot is disabled'}), 403
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'): if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501 return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
response = await runtime_bot.adapter.handle_unified_webhook( response = await runtime_bot.adapter.handle_unified_webhook(
bot_uuid=bot_uuid, bot_uuid=bot_uuid,
path=path, path=path,

View File

@@ -59,7 +59,16 @@ class BotService:
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
# Webhook URL for unified webhook adapters (independent of bot running state) # Webhook URL for unified webhook adapters (independent of bot running state)
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE', 'lark']: if persistence_bot['adapter'] in [
'wecom',
'wecombot',
'officialaccount',
'qqofficial',
'slack',
'wecomcs',
'LINE',
'lark',
]:
webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300') webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300')
webhook_url = f'/bots/{bot_uuid}' webhook_url = f'/bots/{bot_uuid}'
adapter_runtime_values['webhook_url'] = webhook_url adapter_runtime_values['webhook_url'] = webhook_url

View File

@@ -34,7 +34,6 @@ from .. import taskmgr
from ...telemetry import telemetry as telemetry_module from ...telemetry import telemetry as telemetry_module
@stage.stage_class('BuildAppStage') @stage.stage_class('BuildAppStage')
class BuildAppStage(stage.BootingStage): class BuildAppStage(stage.BootingStage):
"""Build LangBot application""" """Build LangBot application"""

View File

@@ -244,7 +244,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time)) lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
if message.message_type == 'text': if message.message_type == 'text':
element_list = [] element_list = []
@@ -310,7 +309,11 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
] ]
elif message.message_type == 'audio': elif message.message_type == 'audio':
message_content['content'] = [ message_content['content'] = [
{'tag': 'audio', 'file_key': message_content['file_key'], "duration": message_content.get('duration',0)} {
'tag': 'audio',
'file_key': message_content['file_key'],
'duration': message_content.get('duration', 0),
}
] ]
for ele in message_content['content']: for ele in message_content['content']:
@@ -367,12 +370,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
audio_bytes = response.file.read() audio_bytes = response.file.read()
audio_base64 = base64.b64encode(audio_bytes).decode() audio_base64 = base64.b64encode(audio_bytes).decode()
# Get content type from response headers # Get content type from response headers
content_type = response.raw.headers.get('content-type', 'audio/mpeg') content_type = response.raw.headers.get('content-type', 'audio/mpeg')
mime_main = content_type.split(';')[0].strip() mime_main = content_type.split(';')[0].strip()
ext = mimetypes.guess_extension(mime_main) or '.bin' ext = mimetypes.guess_extension(mime_main) or '.bin'
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
@@ -418,7 +418,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
file_bytes = response.file.read() file_bytes = response.file.read()
file_base64 = base64.b64encode(file_bytes).decode() file_base64 = base64.b64encode(file_bytes).decode()
file_format = response.raw.headers['content-type'] file_format = response.raw.headers['content-type']
file_size = len(file_bytes) file_size = len(file_bytes)
@@ -453,7 +452,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
) )
) )
return platform_message.MessageChain(lb_msg_list) return platform_message.MessageChain(lb_msg_list)

View File

@@ -75,13 +75,15 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
for msg in message_chain: for msg in message_chain:
if type(msg) is platform_message.Plain: if type(msg) is platform_message.Plain:
chunks = split_string_by_bytes(msg.text) chunks = split_string_by_bytes(msg.text)
content_list.extend([ content_list.extend(
[
{ {
'type': 'text', 'type': 'text',
'content': chunk, 'content': chunk,
} }
for chunk in chunks for chunk in chunks
]) ]
)
elif type(msg) is platform_message.Image: elif type(msg) is platform_message.Image:
content_list.append( content_list.append(
{ {

View File

@@ -51,9 +51,10 @@ class SeekDBEmbedding(requester.ProviderAPIRequester):
await self.initialize() await self.initialize()
if self._embedding_function is None: if self._embedding_function is None:
raise RuntimeError("SeekDB embedding function initialization failed") raise RuntimeError('SeekDB embedding function initialization failed')
return self._embedding_function(input_text) return self._embedding_function(input_text)
except Exception as e: except Exception as e:
from .. import errors from .. import errors
raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}') raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}')

View File

@@ -221,7 +221,11 @@ class LocalAgentRunner(runner.RequestRunner):
# Handle return value content # Handle return value content
tool_content = None tool_content = None
if isinstance(func_ret, list) and len(func_ret) > 0 and isinstance(func_ret[0], provider_message.ContentElement): if (
isinstance(func_ret, list)
and len(func_ret) > 0
and isinstance(func_ret[0], provider_message.ContentElement)
):
tool_content = func_ret tool_content = func_ret
else: else:
tool_content = json.dumps(func_ret, ensure_ascii=False) tool_content = json.dumps(func_ret, ensure_ascii=False)

View File

@@ -68,15 +68,16 @@ class N8nServiceAPIRunner(runner.RequestRunner):
return plain_text return plain_text
async def _process_stream_response(self, response: aiohttp.ClientResponse) -> typing.AsyncGenerator[ async def _process_stream_response(
provider_message.Message, None]: self, response: aiohttp.ClientResponse
) -> typing.AsyncGenerator[provider_message.Message, None]:
"""处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况""" """处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况"""
full_content = "" full_content = ''
chunk_idx = 0 chunk_idx = 0
is_final = False is_final = False
message_idx = 0 message_idx = 0
buffer = "" buffer = ''
decoder = json.JSONDecoder() decoder = json.JSONDecoder()
async for raw_chunk in response.content.iter_chunked(1024): async for raw_chunk in response.content.iter_chunked(1024):
@@ -129,7 +130,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
preview = chunk_str[:200] preview = chunk_str[:200]
except Exception: except Exception:
preview = '<unavailable>' preview = '<unavailable>'
self.ap.logger.warning(f"Failed to process chunk: {e}; chunk preview: {preview}") self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}')
# 流结束后,尝试解析残余 buffer # 流结束后,尝试解析残余 buffer
if buffer: if buffer:
@@ -151,7 +152,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
) )
except Exception as e: except Exception as e:
preview = buffer[:200] preview = buffer[:200]
self.ap.logger.warning(f"Failed to parse remaining buffer: {e}; buffer preview: {preview}") self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}')
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
"""调用n8n webhook""" """调用n8n webhook"""
@@ -165,7 +166,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
# 准备请求数据 # 准备请求数据
payload = { payload = {
# 基本消息内容 # 基本消息内容
'chatInput' :plain_text, # 考虑到之前用户直接用的message model这里添加新键 'chatInput': plain_text, # 考虑到之前用户直接用的message model这里添加新键
'message': plain_text, 'message': plain_text,
'user_message_text': plain_text, 'user_message_text': plain_text,
'conversation_id': query.session.using_conversation.uuid, 'conversation_id': query.session.using_conversation.uuid,
@@ -220,11 +221,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
if is_stream: if is_stream:
# 流式请求 # 流式请求
async with session.post( async with session.post(
self.webhook_url, self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
json=payload,
headers=headers,
auth=auth,
timeout=self.timeout
) as response: ) as response:
if response.status != 200: if response.status != 200:
error_text = await response.text() error_text = await response.text()
@@ -236,11 +233,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
yield chunk yield chunk
else: else:
async with session.post( async with session.post(
self.webhook_url, self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout
json=payload,
headers=headers,
auth=auth,
timeout=self.timeout
) as response: ) as response:
try: try:
async for chunk in self._process_stream_response(response): async for chunk in self._process_stream_response(response):

View File

@@ -194,7 +194,7 @@ class RuntimeMCPSession:
async def func(*, _tool=tool, **kwargs): async def func(*, _tool=tool, **kwargs):
if not self.session: if not self.session:
raise Exception("MCP session is not connected") raise Exception('MCP session is not connected')
result = await self.session.call_tool(_tool.name, kwargs) result = await self.session.call_tool(_tool.name, kwargs)
if result.isError: if result.isError:
@@ -202,7 +202,7 @@ class RuntimeMCPSession:
for content in result.content: for content in result.content:
if content.type == 'text': if content.type == 'text':
error_texts.append(content.text) error_texts.append(content.text)
raise Exception("\n".join(error_texts) if error_texts else "Unknown error from MCP tool") raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool')
result_contents: list[provider_message.ContentElement] = [] result_contents: list[provider_message.ContentElement] = []
for content in result.content: for content in result.content:
@@ -221,8 +221,8 @@ class RuntimeMCPSession:
self.functions.append( self.functions.append(
resource_tool.LLMTool( resource_tool.LLMTool(
name=tool.name, name=tool.name,
human_desc=tool.description or "", human_desc=tool.description or '',
description=tool.description or "", description=tool.description or '',
parameters=tool.inputSchema, parameters=tool.inputSchema,
func=func, func=func,
) )
@@ -338,13 +338,10 @@ class MCPLoader(loader.ToolLoader):
""" """
uuid_ = server_config.get('uuid') uuid_ = server_config.get('uuid')
if not uuid_: if not uuid_:
self.ap.logger.warning( self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.')
'Server UUID is None for MCP server, maybe testing in the config page.'
)
uuid_ = str(uuid_module.uuid4()) uuid_ = str(uuid_module.uuid4())
server_config['uuid'] = uuid_ server_config['uuid'] = uuid_
name = server_config['name'] name = server_config['name']
uuid = server_config['uuid'] uuid = server_config['uuid']
mode = server_config['mode'] mode = server_config['mode']

View File

@@ -37,7 +37,7 @@ class Embedder(BaseService):
embeddings_list: list[list[float]] = [] embeddings_list: list[list[float]] = []
for i in range(0, len(chunks), MAX_BATCH_SIZE): for i in range(0, len(chunks), MAX_BATCH_SIZE):
batch = chunks[i:i + MAX_BATCH_SIZE] batch = chunks[i : i + MAX_BATCH_SIZE]
batch_embeddings = await embedding_model.provider.requester.invoke_embedding( batch_embeddings = await embedding_model.provider.requester.invoke_embedding(
model=embedding_model, model=embedding_model,
input_text=batch, input_text=batch,

View File

@@ -55,12 +55,7 @@ class VectorDBManager:
user = pgvector_config.get('user', 'postgres') user = pgvector_config.get('user', 'postgres')
password = pgvector_config.get('password', 'postgres') password = pgvector_config.get('password', 'postgres')
self.vector_db = PgVectorDatabase( self.vector_db = PgVectorDatabase(
self.ap, self.ap, host=host, port=port, database=database, user=user, password=password
host=host,
port=port,
database=database,
user=user,
password=password
) )
self.ap.logger.info('Initialized pgvector database backend.') self.ap.logger.info('Initialized pgvector database backend.')

View File

@@ -10,7 +10,7 @@ from langbot.pkg.core import app
class MilvusVectorDatabase(VectorDatabase): class MilvusVectorDatabase(VectorDatabase):
"""Milvus vector database implementation""" """Milvus vector database implementation"""
def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None): def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None):
"""Initialize Milvus vector database """Initialize Milvus vector database
Args: Args:
@@ -34,9 +34,9 @@ class MilvusVectorDatabase(VectorDatabase):
self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name) self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name)
else: else:
self.client = MilvusClient(uri=self.uri, db_name=self.db_name) self.client = MilvusClient(uri=self.uri, db_name=self.db_name)
self.ap.logger.info(f"Connected to Milvus at {self.uri}") self.ap.logger.info(f'Connected to Milvus at {self.uri}')
except Exception as e: except Exception as e:
self.ap.logger.error(f"Failed to connect to Milvus: {e}") self.ap.logger.error(f'Failed to connect to Milvus: {e}')
raise raise
@staticmethod @staticmethod
@@ -70,15 +70,11 @@ class MilvusVectorDatabase(VectorDatabase):
""" """
index_params = IndexParams() index_params = IndexParams()
index_params.add_index( index_params.add_index(
field_name="vector", field_name='vector',
index_type="AUTOINDEX", index_type='AUTOINDEX',
metric_type="COSINE", metric_type='COSINE',
)
await asyncio.to_thread(
self.client.create_index,
collection_name=collection,
index_params=index_params
) )
await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params)
async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None): async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None):
"""Internal method to get or create a Milvus collection with proper configuration. """Internal method to get or create a Milvus collection with proper configuration.
@@ -94,9 +90,7 @@ class MilvusVectorDatabase(VectorDatabase):
return collection return collection
# Check if collection exists # Check if collection exists
has_collection = await asyncio.to_thread( has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
self.client.has_collection, collection_name=collection
)
if not has_collection: if not has_collection:
# Default dimension if not specified (for backward compatibility) # Default dimension if not specified (for backward compatibility)
@@ -104,24 +98,26 @@ class MilvusVectorDatabase(VectorDatabase):
vector_size = 1536 vector_size = 1536
fields = [ fields = [
FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255), FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size), FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255), FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255),
FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255), FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255),
] ]
schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors") schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors')
await asyncio.to_thread( await asyncio.to_thread(
self.client.create_collection, self.client.create_collection,
collection_name=collection, collection_name=collection,
schema=schema, schema=schema,
metric_type="COSINE", metric_type='COSINE',
) )
await self._ensure_vector_index(collection) await self._ensure_vector_index(collection)
self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX") self.ap.logger.info(
f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX"
)
else: else:
# Ensure index exists for existing collection # Ensure index exists for existing collection
await self._ensure_index_if_missing(collection) await self._ensure_index_if_missing(collection)
@@ -137,11 +133,8 @@ class MilvusVectorDatabase(VectorDatabase):
collection: Normalized collection name collection: Normalized collection name
""" """
try: try:
indexes = await asyncio.to_thread( indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection)
self.client.list_indexes, if 'vector' not in indexes:
collection_name=collection
)
if "vector" not in indexes:
await self._ensure_vector_index(collection) await self._ensure_vector_index(collection)
self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'") self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'")
except Exception as e: except Exception as e:
@@ -184,39 +177,30 @@ class MilvusVectorDatabase(VectorDatabase):
data = [] data = []
for i, vector_id in enumerate(ids): for i, vector_id in enumerate(ids):
entry = { entry = {
"id": vector_id, 'id': vector_id,
"vector": embeddings_list[i], 'vector': embeddings_list[i],
} }
# Add metadata fields # Add metadata fields
if metadatas and i < len(metadatas): if metadatas and i < len(metadatas):
metadata = metadatas[i] metadata = metadatas[i]
# Add common metadata fields # Add common metadata fields
if "text" in metadata: if 'text' in metadata:
entry["text"] = metadata["text"] entry['text'] = metadata['text']
if "file_id" in metadata: if 'file_id' in metadata:
entry["file_id"] = metadata["file_id"] entry['file_id'] = metadata['file_id']
if "uuid" in metadata: if 'uuid' in metadata:
entry["chunk_uuid"] = metadata["uuid"] entry['chunk_uuid'] = metadata['uuid']
data.append(entry) data.append(entry)
# Insert data into Milvus # Insert data into Milvus
await asyncio.to_thread( await asyncio.to_thread(self.client.insert, collection_name=collection, data=data)
self.client.insert,
collection_name=collection,
data=data
)
# Load collection for searching (Milvus requires this) # Load collection for searching (Milvus requires this)
await asyncio.to_thread( await asyncio.to_thread(self.client.load_collection, collection_name=collection)
self.client.load_collection,
collection_name=collection
)
self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'") self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'")
async def search( async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]:
self, collection: str, query_embedding: list[float], k: int = 5
) -> Dict[str, Any]:
"""Search for similar vectors in Milvus collection """Search for similar vectors in Milvus collection
Args: Args:
@@ -231,10 +215,7 @@ class MilvusVectorDatabase(VectorDatabase):
await self.get_or_create_collection(collection) await self.get_or_create_collection(collection)
# Perform search # Perform search
search_params = { search_params = {'metric_type': 'COSINE', 'params': {}}
"metric_type": "COSINE",
"params": {}
}
results = await asyncio.to_thread( results = await asyncio.to_thread(
self.client.search, self.client.search,
@@ -242,7 +223,7 @@ class MilvusVectorDatabase(VectorDatabase):
data=[query_embedding], data=[query_embedding],
limit=k, limit=k,
search_params=search_params, search_params=search_params,
output_fields=["text", "file_id", "chunk_uuid"] output_fields=['text', 'file_id', 'chunk_uuid'],
) )
# Convert results to Chroma-compatible format # Convert results to Chroma-compatible format
@@ -253,30 +234,24 @@ class MilvusVectorDatabase(VectorDatabase):
if results and len(results) > 0: if results and len(results) > 0:
for hit in results[0]: for hit in results[0]:
ids.append(hit.get("id", "")) ids.append(hit.get('id', ''))
distances.append(hit.get("distance", 0.0)) distances.append(hit.get('distance', 0.0))
# Build metadata from entity fields # Build metadata from entity fields
entity = hit.get("entity", {}) entity = hit.get('entity', {})
metadata = {} metadata = {}
if "text" in entity: if 'text' in entity:
metadata["text"] = entity["text"] metadata['text'] = entity['text']
if "file_id" in entity: if 'file_id' in entity:
metadata["file_id"] = entity["file_id"] metadata['file_id'] = entity['file_id']
if "chunk_uuid" in entity: if 'chunk_uuid' in entity:
metadata["uuid"] = entity["chunk_uuid"] metadata['uuid'] = entity['chunk_uuid']
metadatas.append(metadata) metadatas.append(metadata)
# Return in Chroma-compatible format (nested lists) # Return in Chroma-compatible format (nested lists)
result = { result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]}
"ids": [ids],
"distances": [distances],
"metadatas": [metadatas]
}
self.ap.logger.info( self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results")
f"Milvus search in '{collection}' returned {len(ids)} results"
)
return result return result
async def delete_by_file_id(self, collection: str, file_id: str) -> None: async def delete_by_file_id(self, collection: str, file_id: str) -> None:
@@ -290,14 +265,8 @@ class MilvusVectorDatabase(VectorDatabase):
await self.get_or_create_collection(collection) await self.get_or_create_collection(collection)
# Delete entities matching the file_id # Delete entities matching the file_id
await asyncio.to_thread( await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"')
self.client.delete, self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}")
collection_name=collection,
filter=f'file_id == "{file_id}"'
)
self.ap.logger.info(
f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}"
)
async def delete_collection(self, collection: str): async def delete_collection(self, collection: str):
"""Delete a Milvus collection """Delete a Milvus collection
@@ -310,14 +279,10 @@ class MilvusVectorDatabase(VectorDatabase):
self._collections.discard(collection) self._collections.discard(collection)
# Check if collection exists before attempting deletion # Check if collection exists before attempting deletion
has_collection = await asyncio.to_thread( has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection)
self.client.has_collection, collection_name=collection
)
if has_collection: if has_collection:
await asyncio.to_thread( await asyncio.to_thread(self.client.drop_collection, collection_name=collection)
self.client.drop_collection, collection_name=collection
)
self.ap.logger.info(f"Deleted Milvus collection '{collection}'") self.ap.logger.info(f"Deleted Milvus collection '{collection}'")
else: else:
self.ap.logger.warning(f"Milvus collection '{collection}' not found") self.ap.logger.warning(f"Milvus collection '{collection}' not found")

View File

@@ -23,6 +23,7 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict:
Returns: Returns:
Updated configuration dictionary Updated configuration dictionary
""" """
def convert_value(value: str, original_value: Any) -> Any: def convert_value(value: str, original_value: Any) -> Any:
"""Convert string value to appropriate type based on original value """Convert string value to appropriate type based on original value
@@ -90,11 +91,7 @@ class TestEnvOverrides:
def test_simple_string_override(self): def test_simple_string_override(self):
"""Test overriding a simple string value""" """Test overriding a simple string value"""
cfg = { cfg = {'api': {'port': 5300}}
'api': {
'port': 5300
}
}
# Set environment variable # Set environment variable
os.environ['API__PORT'] = '8080' os.environ['API__PORT'] = '8080'
@@ -108,12 +105,7 @@ class TestEnvOverrides:
def test_nested_key_override(self): def test_nested_key_override(self):
"""Test overriding nested keys with __ delimiter""" """Test overriding nested keys with __ delimiter"""
cfg = { cfg = {'concurrency': {'pipeline': 20, 'session': 1}}
'concurrency': {
'pipeline': 20,
'session': 1
}
}
os.environ['CONCURRENCY__PIPELINE'] = '50' os.environ['CONCURRENCY__PIPELINE'] = '50'
@@ -126,14 +118,7 @@ class TestEnvOverrides:
def test_deep_nested_override(self): def test_deep_nested_override(self):
"""Test overriding deeply nested keys""" """Test overriding deeply nested keys"""
cfg = { cfg = {'system': {'jwt': {'expire': 604800, 'secret': ''}}}
'system': {
'jwt': {
'expire': 604800,
'secret': ''
}
}
}
os.environ['SYSTEM__JWT__EXPIRE'] = '86400' os.environ['SYSTEM__JWT__EXPIRE'] = '86400'
os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key' os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key'
@@ -148,12 +133,7 @@ class TestEnvOverrides:
def test_underscore_in_key(self): def test_underscore_in_key(self):
"""Test keys with underscores like runtime_ws_url""" """Test keys with underscores like runtime_ws_url"""
cfg = { cfg = {'plugin': {'enable': True, 'runtime_ws_url': 'ws://localhost:5400/control/ws'}}
'plugin': {
'enable': True,
'runtime_ws_url': 'ws://localhost:5400/control/ws'
}
}
os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws' os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws'
@@ -165,12 +145,7 @@ class TestEnvOverrides:
def test_boolean_conversion(self): def test_boolean_conversion(self):
"""Test boolean value conversion""" """Test boolean value conversion"""
cfg = { cfg = {'plugin': {'enable': True, 'enable_marketplace': False}}
'plugin': {
'enable': True,
'enable_marketplace': False
}
}
os.environ['PLUGIN__ENABLE'] = 'false' os.environ['PLUGIN__ENABLE'] = 'false'
os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true' os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true'
@@ -185,14 +160,7 @@ class TestEnvOverrides:
def test_ignore_dict_type(self): def test_ignore_dict_type(self):
"""Test that dict types are ignored""" """Test that dict types are ignored"""
cfg = { cfg = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}}
'database': {
'use': 'sqlite',
'sqlite': {
'path': 'data/langbot.db'
}
}
}
# Try to override a dict value - should be ignored # Try to override a dict value - should be ignored
os.environ['DATABASE__SQLITE'] = 'new_value' os.environ['DATABASE__SQLITE'] = 'new_value'
@@ -207,13 +175,7 @@ class TestEnvOverrides:
def test_ignore_list_type(self): def test_ignore_list_type(self):
"""Test that list/array types are ignored""" """Test that list/array types are ignored"""
cfg = { cfg = {'admins': ['admin1', 'admin2'], 'command': {'enable': True, 'prefix': ['!', '']}}
'admins': ['admin1', 'admin2'],
'command': {
'enable': True,
'prefix': ['!', '']
}
}
# Try to override list values - should be ignored # Try to override list values - should be ignored
os.environ['ADMINS'] = 'admin3' os.environ['ADMINS'] = 'admin3'
@@ -232,11 +194,7 @@ class TestEnvOverrides:
def test_lowercase_env_var_ignored(self): def test_lowercase_env_var_ignored(self):
"""Test that lowercase environment variables are ignored""" """Test that lowercase environment variables are ignored"""
cfg = { cfg = {'api': {'port': 5300}}
'api': {
'port': 5300
}
}
os.environ['api__port'] = '8080' os.environ['api__port'] = '8080'
@@ -249,11 +207,7 @@ class TestEnvOverrides:
def test_no_double_underscore_ignored(self): def test_no_double_underscore_ignored(self):
"""Test that env vars without __ are ignored""" """Test that env vars without __ are ignored"""
cfg = { cfg = {'api': {'port': 5300}}
'api': {
'port': 5300
}
}
os.environ['APIPORT'] = '8080' os.environ['APIPORT'] = '8080'
@@ -266,11 +220,7 @@ class TestEnvOverrides:
def test_nonexistent_key_ignored(self): def test_nonexistent_key_ignored(self):
"""Test that env vars for non-existent keys are ignored""" """Test that env vars for non-existent keys are ignored"""
cfg = { cfg = {'api': {'port': 5300}}
'api': {
'port': 5300
}
}
os.environ['API__NONEXISTENT'] = 'value' os.environ['API__NONEXISTENT'] = 'value'
@@ -283,11 +233,7 @@ class TestEnvOverrides:
def test_integer_conversion(self): def test_integer_conversion(self):
"""Test integer value conversion""" """Test integer value conversion"""
cfg = { cfg = {'concurrency': {'pipeline': 20}}
'concurrency': {
'pipeline': 20
}
}
os.environ['CONCURRENCY__PIPELINE'] = '100' os.environ['CONCURRENCY__PIPELINE'] = '100'
@@ -300,18 +246,7 @@ class TestEnvOverrides:
def test_multiple_overrides(self): def test_multiple_overrides(self):
"""Test multiple environment variable overrides at once""" """Test multiple environment variable overrides at once"""
cfg = { cfg = {'api': {'port': 5300}, 'concurrency': {'pipeline': 20, 'session': 1}, 'plugin': {'enable': False}}
'api': {
'port': 5300
},
'concurrency': {
'pipeline': 20,
'session': 1
},
'plugin': {
'enable': False
}
}
os.environ['API__PORT'] = '8080' os.environ['API__PORT'] = '8080'
os.environ['CONCURRENCY__PIPELINE'] = '50' os.environ['CONCURRENCY__PIPELINE'] = '50'

View File

@@ -1,6 +1,5 @@
"""Test plugin list filtering by component kinds.""" """Test plugin list filtering by component kinds."""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@@ -31,16 +30,7 @@ async def test_plugin_list_filter_by_component_kinds():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'Tool',
'metadata': {'name': 'tool1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -53,15 +43,8 @@ async def test_plugin_list_filter_by_component_kinds():
} }
}, },
'components': [ 'components': [
{ {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
'manifest': { ],
'manifest': {
'kind': 'KnowledgeRetriever',
'metadata': {'name': 'retriever1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -73,16 +56,7 @@ async def test_plugin_list_filter_by_component_kinds():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'Command', 'metadata': {'name': 'cmd1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'Command',
'metadata': {'name': 'cmd1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -94,16 +68,7 @@ async def test_plugin_list_filter_by_component_kinds():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'EventListener', 'metadata': {'name': 'listener1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'EventListener',
'metadata': {'name': 'listener1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -116,23 +81,9 @@ async def test_plugin_list_filter_by_component_kinds():
} }
}, },
'components': [ 'components': [
{ {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever2'}}}},
'manifest': { {'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}},
'manifest': { ],
'kind': 'KnowledgeRetriever',
'metadata': {'name': 'retriever2'}
}
}
},
{
'manifest': {
'manifest': {
'kind': 'Tool',
'metadata': {'name': 'tool2'}
}
}
}
]
}, },
] ]
@@ -187,16 +138,7 @@ async def test_plugin_list_filter_no_filter():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'Tool',
'metadata': {'name': 'tool1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -209,15 +151,8 @@ async def test_plugin_list_filter_no_filter():
} }
}, },
'components': [ 'components': [
{ {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
'manifest': { ],
'manifest': {
'kind': 'KnowledgeRetriever',
'metadata': {'name': 'retriever1'}
}
}
}
]
}, },
] ]
@@ -267,15 +202,8 @@ async def test_plugin_list_filter_empty_result():
} }
}, },
'components': [ 'components': [
{ {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}}
'manifest': { ],
'manifest': {
'kind': 'KnowledgeRetriever',
'metadata': {'name': 'retriever1'}
}
}
}
]
}, },
] ]
@@ -321,16 +249,7 @@ async def test_plugin_list_filter_plugin_without_components():
} }
} }
}, },
'components': [ 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}],
{
'manifest': {
'manifest': {
'kind': 'Tool',
'metadata': {'name': 'tool1'}
}
}
}
]
}, },
{ {
'debug': False, 'debug': False,
@@ -342,7 +261,7 @@ async def test_plugin_list_filter_plugin_without_components():
} }
} }
}, },
'components': [] 'components': [],
}, },
] ]