diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..e1d89c1e --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,60 @@ +name: Lint + +on: + push: + branches: + - main + - master + - dev + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +jobs: + ruff: + name: Ruff Lint & Format + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install uv + uses: astral-sh/setup-uv@v4 + + - name: Install dependencies + run: uv sync --dev + + - name: Run ruff check + run: uv run ruff check src + + - name: Run ruff format + run: uv run ruff format src --check + + frontend: + name: Frontend Lint + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '25' + + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + version: 9 + + - name: Install dependencies + working-directory: web + run: pnpm install + + - name: Run lint + working-directory: web + run: pnpm lint diff --git a/src/langbot/libs/qq_official_api/api.py b/src/langbot/libs/qq_official_api/api.py index e4d4e468..51a56d53 100644 --- a/src/langbot/libs/qq_official_api/api.py +++ b/src/langbot/libs/qq_official_api/api.py @@ -85,7 +85,6 @@ class QQOfficialClient: req: Quart Request 对象 """ try: - body = await req.get_data() print(f'[QQ Official] Received request, body length: {len(body)}') @@ -96,7 +95,6 @@ class QQOfficialClient: payload = json.loads(body) - if payload.get('op') == 13: validation_data = payload.get('d') if not validation_data: @@ -276,21 +274,21 @@ class QQOfficialClient: seed = bot_secret while len(seed) < target_size: seed *= 2 - return seed[:target_size].encode("utf-8") + return seed[:target_size].encode('utf-8') async def verify(self, validation_payload: dict): seed = await self.repeat_seed(self.secret) private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed) - event_ts = validation_payload.get("event_ts", "") - plain_token = validation_payload.get("plain_token", "") + event_ts = validation_payload.get('event_ts', '') + plain_token = validation_payload.get('plain_token', '') msg = event_ts + plain_token # sign signature = private_key.sign(msg.encode()).hex() response = { - "plain_token": plain_token, - "signature": signature, + 'plain_token': plain_token, + 'signature': signature, } return response diff --git a/src/langbot/libs/wecom_ai_bot_api/wecombotevent.py b/src/langbot/libs/wecom_ai_bot_api/wecombotevent.py index 75c6bbde..bc105cf8 100644 --- a/src/langbot/libs/wecom_ai_bot_api/wecombotevent.py +++ b/src/langbot/libs/wecom_ai_bot_api/wecombotevent.py @@ -36,7 +36,12 @@ class WecomBotEvent(dict): """ 用户名称 """ - return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid + return ( + self.get('username', '') + or self.get('from', {}).get('alias', '') + or self.get('from', {}).get('name', '') + or self.userid + ) @property def chatname(self) -> str: @@ -121,7 +126,7 @@ class WecomBotEvent(dict): 消息id """ return self.get('msgid', '') - + @property def ai_bot_id(self) -> str: """ diff --git a/src/langbot/pkg/api/http/controller/groups/webhooks.py b/src/langbot/pkg/api/http/controller/groups/webhooks.py index 0964076f..ec46c744 100644 --- a/src/langbot/pkg/api/http/controller/groups/webhooks.py +++ b/src/langbot/pkg/api/http/controller/groups/webhooks.py @@ -30,7 +30,6 @@ class WebhookRouterGroup(group.RouterGroup): 适配器返回的响应 """ try: - runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid) if not runtime_bot: @@ -39,11 +38,9 @@ class WebhookRouterGroup(group.RouterGroup): if not runtime_bot.enable: return quart.jsonify({'error': 'Bot is disabled'}), 403 - if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'): return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501 - response = await runtime_bot.adapter.handle_unified_webhook( bot_uuid=bot_uuid, path=path, diff --git a/src/langbot/pkg/api/http/service/bot.py b/src/langbot/pkg/api/http/service/bot.py index ac7ec13a..0632935b 100644 --- a/src/langbot/pkg/api/http/service/bot.py +++ b/src/langbot/pkg/api/http/service/bot.py @@ -59,7 +59,16 @@ class BotService: adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id # Webhook URL for unified webhook adapters (independent of bot running state) - if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE', 'lark']: + if persistence_bot['adapter'] in [ + 'wecom', + 'wecombot', + 'officialaccount', + 'qqofficial', + 'slack', + 'wecomcs', + 'LINE', + 'lark', + ]: webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300') webhook_url = f'/bots/{bot_uuid}' adapter_runtime_values['webhook_url'] = webhook_url diff --git a/src/langbot/pkg/core/stages/build_app.py b/src/langbot/pkg/core/stages/build_app.py index e7226041..791b5a9e 100644 --- a/src/langbot/pkg/core/stages/build_app.py +++ b/src/langbot/pkg/core/stages/build_app.py @@ -34,7 +34,6 @@ from .. import taskmgr from ...telemetry import telemetry as telemetry_module - @stage.stage_class('BuildAppStage') class BuildAppStage(stage.BootingStage): """Build LangBot application""" diff --git a/src/langbot/pkg/platform/sources/lark.py b/src/langbot/pkg/platform/sources/lark.py index 3c13c019..f123889c 100644 --- a/src/langbot/pkg/platform/sources/lark.py +++ b/src/langbot/pkg/platform/sources/lark.py @@ -244,7 +244,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time)) - if message.message_type == 'text': element_list = [] @@ -310,7 +309,11 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): ] elif message.message_type == 'audio': message_content['content'] = [ - {'tag': 'audio', 'file_key': message_content['file_key'], "duration": message_content.get('duration',0)} + { + 'tag': 'audio', + 'file_key': message_content['file_key'], + 'duration': message_content.get('duration', 0), + } ] for ele in message_content['content']: @@ -367,12 +370,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): audio_bytes = response.file.read() audio_base64 = base64.b64encode(audio_bytes).decode() - # Get content type from response headers content_type = response.raw.headers.get('content-type', 'audio/mpeg') - - mime_main = content_type.split(';')[0].strip() ext = mimetypes.guess_extension(mime_main) or '.bin' temp_dir = tempfile.gettempdir() @@ -418,7 +418,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): file_bytes = response.file.read() file_base64 = base64.b64encode(file_bytes).decode() - file_format = response.raw.headers['content-type'] file_size = len(file_bytes) @@ -453,7 +452,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): ) ) - return platform_message.MessageChain(lb_msg_list) diff --git a/src/langbot/pkg/platform/sources/wecom.py b/src/langbot/pkg/platform/sources/wecom.py index dc14a9b7..7bed676f 100644 --- a/src/langbot/pkg/platform/sources/wecom.py +++ b/src/langbot/pkg/platform/sources/wecom.py @@ -18,52 +18,52 @@ import langbot_plugin.api.entities.builtin.platform.entities as platform_entitie def split_string_by_bytes(text, limit=2048, encoding='utf-8'): """ Splits a string into a list of strings, where each part is at most 'limit' bytes. - + Args: text (str): The original string to split. limit (int): The maximum byte size for each split part. encoding (str): The encoding to use (default is 'utf-8'). - + Returns: list: A list of split strings. """ # 1. Encode the entire string into bytes bytes_data = text.encode(encoding) total_len = len(bytes_data) - + parts = [] start = 0 - + while start < total_len: # 2. Determine the end index for the current chunk # It shouldn't exceed the total length end = min(start + limit, total_len) - + # 3. Slice the byte array chunk = bytes_data[start:end] - + # 4. Attempt to decode the chunk # Use errors='ignore' to drop any partial bytes at the end of the chunk # (e.g., if a 3-byte character was cut after the 2nd byte) part_str = chunk.decode(encoding, errors='ignore') - + # 5. Calculate the actual byte length of the successfully decoded string # This tells us exactly where the valid character boundary ended part_bytes = part_str.encode(encoding) part_len = len(part_bytes) - + # Safety check: Prevent infinite loop if limit is too small (e.g., limit=1 for a Chinese char) if part_len == 0 and end < total_len: # Force advance by 1 byte to consume the un-decodable byte or raise error # Here we just treat it as a part to avoid stuck loops, though it might be invalid - start += 1 + start += 1 continue parts.append(part_str) - + # 6. Move the start pointer by the actual length consumed start += part_len - + return parts @@ -75,13 +75,15 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter): for msg in message_chain: if type(msg) is platform_message.Plain: chunks = split_string_by_bytes(msg.text) - content_list.extend([ - { - 'type': 'text', - 'content': chunk, - } - for chunk in chunks - ]) + content_list.extend( + [ + { + 'type': 'text', + 'content': chunk, + } + for chunk in chunks + ] + ) elif type(msg) is platform_message.Image: content_list.append( { diff --git a/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.py b/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.py index a181d51c..7fd98d69 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.py @@ -51,9 +51,10 @@ class SeekDBEmbedding(requester.ProviderAPIRequester): await self.initialize() if self._embedding_function is None: - raise RuntimeError("SeekDB embedding function initialization failed") + raise RuntimeError('SeekDB embedding function initialization failed') return self._embedding_function(input_text) except Exception as e: from .. import errors + raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}') diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index 5e5be37c..64dbfc63 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -218,10 +218,14 @@ class LocalAgentRunner(runner.RequestRunner): parameters = {} func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query) - + # Handle return value content tool_content = None - if isinstance(func_ret, list) and len(func_ret) > 0 and isinstance(func_ret[0], provider_message.ContentElement): + if ( + isinstance(func_ret, list) + and len(func_ret) > 0 + and isinstance(func_ret[0], provider_message.ContentElement) + ): tool_content = func_ret else: tool_content = json.dumps(func_ret, ensure_ascii=False) diff --git a/src/langbot/pkg/provider/runners/n8nsvapi.py b/src/langbot/pkg/provider/runners/n8nsvapi.py index 89cb6679..d7ec3ccb 100644 --- a/src/langbot/pkg/provider/runners/n8nsvapi.py +++ b/src/langbot/pkg/provider/runners/n8nsvapi.py @@ -68,15 +68,16 @@ class N8nServiceAPIRunner(runner.RequestRunner): return plain_text - async def _process_stream_response(self, response: aiohttp.ClientResponse) -> typing.AsyncGenerator[ - provider_message.Message, None]: + async def _process_stream_response( + self, response: aiohttp.ClientResponse + ) -> typing.AsyncGenerator[provider_message.Message, None]: """处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况""" - full_content = "" + full_content = '' chunk_idx = 0 is_final = False message_idx = 0 - buffer = "" + buffer = '' decoder = json.JSONDecoder() async for raw_chunk in response.content.iter_chunked(1024): @@ -129,7 +130,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): preview = chunk_str[:200] except Exception: preview = '' - self.ap.logger.warning(f"Failed to process chunk: {e}; chunk preview: {preview}") + self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}') # 流结束后,尝试解析残余 buffer if buffer: @@ -151,7 +152,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): ) except Exception as e: preview = buffer[:200] - self.ap.logger.warning(f"Failed to parse remaining buffer: {e}; buffer preview: {preview}") + self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}') async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """调用n8n webhook""" @@ -165,7 +166,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): # 准备请求数据 payload = { # 基本消息内容 - 'chatInput' :plain_text, # 考虑到之前用户直接用的message model这里添加新键 + 'chatInput': plain_text, # 考虑到之前用户直接用的message model这里添加新键 'message': plain_text, 'user_message_text': plain_text, 'conversation_id': query.session.using_conversation.uuid, @@ -217,57 +218,49 @@ class N8nServiceAPIRunner(runner.RequestRunner): # 调用webhook async with aiohttp.ClientSession() as session: - if is_stream: - # 流式请求 - async with session.post( - self.webhook_url, - json=payload, - headers=headers, - auth=auth, - timeout=self.timeout - ) as response: + if is_stream: + # 流式请求 + async with session.post( + self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout + ) as response: + if response.status != 200: + error_text = await response.text() + self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') + raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') + + # 处理流式响应 + async for chunk in self._process_stream_response(response): + yield chunk + else: + async with session.post( + self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout + ) as response: + try: + async for chunk in self._process_stream_response(response): + output_content = chunk.content if chunk.is_final else '' + except: + # 非流式请求(保持原有逻辑) if response.status != 200: error_text = await response.text() self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') - # 处理流式响应 - async for chunk in self._process_stream_response(response): - yield chunk - else: - async with session.post( - self.webhook_url, - json=payload, - headers=headers, - auth=auth, - timeout=self.timeout - ) as response: - try: - async for chunk in self._process_stream_response(response): - output_content = chunk.content if chunk.is_final else '' - except: - # 非流式请求(保持原有逻辑) - if response.status != 200: - error_text = await response.text() - self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') - raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') + # 解析响应 + response_data = await response.json() + self.ap.logger.debug(f'n8n webhook response: {response_data}') - # 解析响应 - response_data = await response.json() - self.ap.logger.debug(f'n8n webhook response: {response_data}') + # 从响应中提取输出 + if self.output_key in response_data: + output_content = response_data[self.output_key] + else: + # 如果没有指定的输出键,则使用整个响应 + output_content = json.dumps(response_data, ensure_ascii=False) - # 从响应中提取输出 - if self.output_key in response_data: - output_content = response_data[self.output_key] - else: - # 如果没有指定的输出键,则使用整个响应 - output_content = json.dumps(response_data, ensure_ascii=False) - - # 返回消息 - yield provider_message.Message( - role='assistant', - content=output_content, - ) + # 返回消息 + yield provider_message.Message( + role='assistant', + content=output_content, + ) except Exception as e: self.ap.logger.error(f'n8n webhook call exception: {str(e)}') raise N8nAPIError(f'n8n webhook call exception: {str(e)}') @@ -275,4 +268,4 @@ class N8nServiceAPIRunner(runner.RequestRunner): async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """运行请求""" async for msg in self._call_webhook(query): - yield msg \ No newline at end of file + yield msg diff --git a/src/langbot/pkg/provider/tools/loaders/mcp.py b/src/langbot/pkg/provider/tools/loaders/mcp.py index 4b0583c6..46d63b84 100644 --- a/src/langbot/pkg/provider/tools/loaders/mcp.py +++ b/src/langbot/pkg/provider/tools/loaders/mcp.py @@ -194,7 +194,7 @@ class RuntimeMCPSession: async def func(*, _tool=tool, **kwargs): if not self.session: - raise Exception("MCP session is not connected") + raise Exception('MCP session is not connected') result = await self.session.call_tool(_tool.name, kwargs) if result.isError: @@ -202,8 +202,8 @@ class RuntimeMCPSession: for content in result.content: if content.type == 'text': error_texts.append(content.text) - raise Exception("\n".join(error_texts) if error_texts else "Unknown error from MCP tool") - + raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool') + result_contents: list[provider_message.ContentElement] = [] for content in result.content: if content.type == 'text': @@ -213,7 +213,7 @@ class RuntimeMCPSession: elif content.type == 'resource': # TODO: Handle resource content pass - + return result_contents func.__name__ = tool.name @@ -221,8 +221,8 @@ class RuntimeMCPSession: self.functions.append( resource_tool.LLMTool( name=tool.name, - human_desc=tool.description or "", - description=tool.description or "", + human_desc=tool.description or '', + description=tool.description or '', parameters=tool.inputSchema, func=func, ) @@ -338,13 +338,10 @@ class MCPLoader(loader.ToolLoader): """ uuid_ = server_config.get('uuid') if not uuid_: - self.ap.logger.warning( - 'Server UUID is None for MCP server, maybe testing in the config page.' - ) + self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.') uuid_ = str(uuid_module.uuid4()) server_config['uuid'] = uuid_ - name = server_config['name'] uuid = server_config['uuid'] mode = server_config['mode'] diff --git a/src/langbot/pkg/rag/knowledge/services/embedder.py b/src/langbot/pkg/rag/knowledge/services/embedder.py index f93382ff..485bc21e 100644 --- a/src/langbot/pkg/rag/knowledge/services/embedder.py +++ b/src/langbot/pkg/rag/knowledge/services/embedder.py @@ -35,9 +35,9 @@ class Embedder(BaseService): # get embeddings (batch size limit: 64 for OpenAI) MAX_BATCH_SIZE = 64 embeddings_list: list[list[float]] = [] - + for i in range(0, len(chunks), MAX_BATCH_SIZE): - batch = chunks[i:i + MAX_BATCH_SIZE] + batch = chunks[i : i + MAX_BATCH_SIZE] batch_embeddings = await embedding_model.provider.requester.invoke_embedding( model=embedding_model, input_text=batch, diff --git a/src/langbot/pkg/vector/mgr.py b/src/langbot/pkg/vector/mgr.py index f95f5f75..f0cb742c 100644 --- a/src/langbot/pkg/vector/mgr.py +++ b/src/langbot/pkg/vector/mgr.py @@ -55,12 +55,7 @@ class VectorDBManager: user = pgvector_config.get('user', 'postgres') password = pgvector_config.get('password', 'postgres') self.vector_db = PgVectorDatabase( - self.ap, - host=host, - port=port, - database=database, - user=user, - password=password + self.ap, host=host, port=port, database=database, user=user, password=password ) self.ap.logger.info('Initialized pgvector database backend.') diff --git a/src/langbot/pkg/vector/vdbs/milvus.py b/src/langbot/pkg/vector/vdbs/milvus.py index f15071c4..2852dea1 100644 --- a/src/langbot/pkg/vector/vdbs/milvus.py +++ b/src/langbot/pkg/vector/vdbs/milvus.py @@ -10,7 +10,7 @@ from langbot.pkg.core import app class MilvusVectorDatabase(VectorDatabase): """Milvus vector database implementation""" - def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None): + def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None): """Initialize Milvus vector database Args: @@ -34,32 +34,32 @@ class MilvusVectorDatabase(VectorDatabase): self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name) else: self.client = MilvusClient(uri=self.uri, db_name=self.db_name) - self.ap.logger.info(f"Connected to Milvus at {self.uri}") + self.ap.logger.info(f'Connected to Milvus at {self.uri}') except Exception as e: - self.ap.logger.error(f"Failed to connect to Milvus: {e}") + self.ap.logger.error(f'Failed to connect to Milvus: {e}') raise @staticmethod def _normalize_collection_name(collection: str) -> str: """Normalize collection name to comply with Milvus naming requirements. - + Milvus requirements: - First character must be an underscore or letter - Can only contain numbers, letters and underscores - + Args: collection: Original collection name (e.g., UUID with hyphens) - + Returns: Normalized collection name that complies with Milvus requirements """ # Replace hyphens with underscores normalized = collection.replace('-', '_') - + # If first character is not a letter or underscore, prepend 'kb_' if normalized and not (normalized[0].isalpha() or normalized[0] == '_'): normalized = 'kb_' + normalized - + return normalized async def _ensure_vector_index(self, collection: str) -> None: @@ -70,15 +70,11 @@ class MilvusVectorDatabase(VectorDatabase): """ index_params = IndexParams() index_params.add_index( - field_name="vector", - index_type="AUTOINDEX", - metric_type="COSINE", - ) - await asyncio.to_thread( - self.client.create_index, - collection_name=collection, - index_params=index_params + field_name='vector', + index_type='AUTOINDEX', + metric_type='COSINE', ) + await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params) async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None): """Internal method to get or create a Milvus collection with proper configuration. @@ -89,14 +85,12 @@ class MilvusVectorDatabase(VectorDatabase): """ # Normalize collection name for Milvus compatibility collection = self._normalize_collection_name(collection) - + if collection in self._collections: return collection # Check if collection exists - has_collection = await asyncio.to_thread( - self.client.has_collection, collection_name=collection - ) + has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection) if not has_collection: # Default dimension if not specified (for backward compatibility) @@ -104,24 +98,26 @@ class MilvusVectorDatabase(VectorDatabase): vector_size = 1536 fields = [ - FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255), - FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size), - FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), - FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255), - FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255), + FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255), + FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size), + FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535), + FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255), + FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255), ] - schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors") + schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors') await asyncio.to_thread( self.client.create_collection, collection_name=collection, schema=schema, - metric_type="COSINE", + metric_type='COSINE', ) await self._ensure_vector_index(collection) - self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX") + self.ap.logger.info( + f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX" + ) else: # Ensure index exists for existing collection await self._ensure_index_if_missing(collection) @@ -137,11 +133,8 @@ class MilvusVectorDatabase(VectorDatabase): collection: Normalized collection name """ try: - indexes = await asyncio.to_thread( - self.client.list_indexes, - collection_name=collection - ) - if "vector" not in indexes: + indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection) + if 'vector' not in indexes: await self._ensure_vector_index(collection) self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'") except Exception as e: @@ -172,7 +165,7 @@ class MilvusVectorDatabase(VectorDatabase): metadatas: List of metadata dictionaries for each vector """ collection = self._normalize_collection_name(collection) - + if not embeddings_list: return @@ -184,39 +177,30 @@ class MilvusVectorDatabase(VectorDatabase): data = [] for i, vector_id in enumerate(ids): entry = { - "id": vector_id, - "vector": embeddings_list[i], + 'id': vector_id, + 'vector': embeddings_list[i], } # Add metadata fields if metadatas and i < len(metadatas): metadata = metadatas[i] # Add common metadata fields - if "text" in metadata: - entry["text"] = metadata["text"] - if "file_id" in metadata: - entry["file_id"] = metadata["file_id"] - if "uuid" in metadata: - entry["chunk_uuid"] = metadata["uuid"] + if 'text' in metadata: + entry['text'] = metadata['text'] + if 'file_id' in metadata: + entry['file_id'] = metadata['file_id'] + if 'uuid' in metadata: + entry['chunk_uuid'] = metadata['uuid'] data.append(entry) # Insert data into Milvus - await asyncio.to_thread( - self.client.insert, - collection_name=collection, - data=data - ) + await asyncio.to_thread(self.client.insert, collection_name=collection, data=data) # Load collection for searching (Milvus requires this) - await asyncio.to_thread( - self.client.load_collection, - collection_name=collection - ) + await asyncio.to_thread(self.client.load_collection, collection_name=collection) self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'") - async def search( - self, collection: str, query_embedding: list[float], k: int = 5 - ) -> Dict[str, Any]: + async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]: """Search for similar vectors in Milvus collection Args: @@ -231,10 +215,7 @@ class MilvusVectorDatabase(VectorDatabase): await self.get_or_create_collection(collection) # Perform search - search_params = { - "metric_type": "COSINE", - "params": {} - } + search_params = {'metric_type': 'COSINE', 'params': {}} results = await asyncio.to_thread( self.client.search, @@ -242,7 +223,7 @@ class MilvusVectorDatabase(VectorDatabase): data=[query_embedding], limit=k, search_params=search_params, - output_fields=["text", "file_id", "chunk_uuid"] + output_fields=['text', 'file_id', 'chunk_uuid'], ) # Convert results to Chroma-compatible format @@ -253,30 +234,24 @@ class MilvusVectorDatabase(VectorDatabase): if results and len(results) > 0: for hit in results[0]: - ids.append(hit.get("id", "")) - distances.append(hit.get("distance", 0.0)) + ids.append(hit.get('id', '')) + distances.append(hit.get('distance', 0.0)) # Build metadata from entity fields - entity = hit.get("entity", {}) + entity = hit.get('entity', {}) metadata = {} - if "text" in entity: - metadata["text"] = entity["text"] - if "file_id" in entity: - metadata["file_id"] = entity["file_id"] - if "chunk_uuid" in entity: - metadata["uuid"] = entity["chunk_uuid"] + if 'text' in entity: + metadata['text'] = entity['text'] + if 'file_id' in entity: + metadata['file_id'] = entity['file_id'] + if 'chunk_uuid' in entity: + metadata['uuid'] = entity['chunk_uuid'] metadatas.append(metadata) # Return in Chroma-compatible format (nested lists) - result = { - "ids": [ids], - "distances": [distances], - "metadatas": [metadatas] - } + result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]} - self.ap.logger.info( - f"Milvus search in '{collection}' returned {len(ids)} results" - ) + self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results") return result async def delete_by_file_id(self, collection: str, file_id: str) -> None: @@ -290,14 +265,8 @@ class MilvusVectorDatabase(VectorDatabase): await self.get_or_create_collection(collection) # Delete entities matching the file_id - await asyncio.to_thread( - self.client.delete, - collection_name=collection, - filter=f'file_id == "{file_id}"' - ) - self.ap.logger.info( - f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}" - ) + await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"') + self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}") async def delete_collection(self, collection: str): """Delete a Milvus collection @@ -306,18 +275,14 @@ class MilvusVectorDatabase(VectorDatabase): collection: Collection name to delete """ collection = self._normalize_collection_name(collection) - + self._collections.discard(collection) # Check if collection exists before attempting deletion - has_collection = await asyncio.to_thread( - self.client.has_collection, collection_name=collection - ) + has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection) if has_collection: - await asyncio.to_thread( - self.client.drop_collection, collection_name=collection - ) + await asyncio.to_thread(self.client.drop_collection, collection_name=collection) self.ap.logger.info(f"Deleted Milvus collection '{collection}'") else: self.ap.logger.warning(f"Milvus collection '{collection}' not found") diff --git a/tests/unit_tests/config/test_env_override.py b/tests/unit_tests/config/test_env_override.py index d20988e9..0e309d4c 100644 --- a/tests/unit_tests/config/test_env_override.py +++ b/tests/unit_tests/config/test_env_override.py @@ -9,27 +9,28 @@ from typing import Any def _apply_env_overrides_to_config(cfg: dict) -> dict: """Apply environment variable overrides to data/config.yaml - - Environment variables should be uppercase and use __ (double underscore) + + Environment variables should be uppercase and use __ (double underscore) to represent nested keys. For example: - CONCURRENCY__PIPELINE overrides concurrency.pipeline - PLUGIN__RUNTIME_WS_URL overrides plugin.runtime_ws_url - + Arrays and dict types are ignored. - + Args: cfg: Configuration dictionary - + Returns: Updated configuration dictionary """ + def convert_value(value: str, original_value: Any) -> Any: """Convert string value to appropriate type based on original value - + Args: value: String value from environment variable original_value: Original value to infer type from - + Returns: Converted value (falls back to string if conversion fails) """ @@ -49,7 +50,7 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict: return value else: return value - + # Process environment variables for env_key, env_value in os.environ.items(): # Check if the environment variable is uppercase and contains __ @@ -57,18 +58,18 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict: continue if '__' not in env_key: continue - + # Convert environment variable name to config path # e.g., CONCURRENCY__PIPELINE -> ['concurrency', 'pipeline'] keys = [key.lower() for key in env_key.split('__')] - + # Navigate to the target value and validate the path current = cfg - + for i, key in enumerate(keys): if not isinstance(current, dict) or key not in current: break - + if i == len(keys) - 1: # At the final key - check if it's a scalar value if isinstance(current[key], (dict, list)): @@ -81,248 +82,182 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict: else: # Navigate deeper current = current[key] - + return cfg class TestEnvOverrides: """Test environment variable override functionality""" - + def test_simple_string_override(self): """Test overriding a simple string value""" - cfg = { - 'api': { - 'port': 5300 - } - } - + cfg = {'api': {'port': 5300}} + # Set environment variable os.environ['API__PORT'] = '8080' - + result = _apply_env_overrides_to_config(cfg) - + assert result['api']['port'] == 8080 - + # Cleanup del os.environ['API__PORT'] - + def test_nested_key_override(self): """Test overriding nested keys with __ delimiter""" - cfg = { - 'concurrency': { - 'pipeline': 20, - 'session': 1 - } - } - + cfg = {'concurrency': {'pipeline': 20, 'session': 1}} + os.environ['CONCURRENCY__PIPELINE'] = '50' - + result = _apply_env_overrides_to_config(cfg) - + assert result['concurrency']['pipeline'] == 50 assert result['concurrency']['session'] == 1 # Unchanged - + del os.environ['CONCURRENCY__PIPELINE'] - + def test_deep_nested_override(self): """Test overriding deeply nested keys""" - cfg = { - 'system': { - 'jwt': { - 'expire': 604800, - 'secret': '' - } - } - } - + cfg = {'system': {'jwt': {'expire': 604800, 'secret': ''}}} + os.environ['SYSTEM__JWT__EXPIRE'] = '86400' os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key' - + result = _apply_env_overrides_to_config(cfg) - + assert result['system']['jwt']['expire'] == 86400 assert result['system']['jwt']['secret'] == 'my_secret_key' - + del os.environ['SYSTEM__JWT__EXPIRE'] del os.environ['SYSTEM__JWT__SECRET'] - + def test_underscore_in_key(self): """Test keys with underscores like runtime_ws_url""" - cfg = { - 'plugin': { - 'enable': True, - 'runtime_ws_url': 'ws://localhost:5400/control/ws' - } - } - + cfg = {'plugin': {'enable': True, 'runtime_ws_url': 'ws://localhost:5400/control/ws'}} + os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws' - + result = _apply_env_overrides_to_config(cfg) - + assert result['plugin']['runtime_ws_url'] == 'ws://newhost:6000/ws' - + del os.environ['PLUGIN__RUNTIME_WS_URL'] - + def test_boolean_conversion(self): """Test boolean value conversion""" - cfg = { - 'plugin': { - 'enable': True, - 'enable_marketplace': False - } - } - + cfg = {'plugin': {'enable': True, 'enable_marketplace': False}} + os.environ['PLUGIN__ENABLE'] = 'false' os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true' - + result = _apply_env_overrides_to_config(cfg) - + assert result['plugin']['enable'] is False assert result['plugin']['enable_marketplace'] is True - + del os.environ['PLUGIN__ENABLE'] del os.environ['PLUGIN__ENABLE_MARKETPLACE'] - + def test_ignore_dict_type(self): """Test that dict types are ignored""" - cfg = { - 'database': { - 'use': 'sqlite', - 'sqlite': { - 'path': 'data/langbot.db' - } - } - } - + cfg = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}} + # Try to override a dict value - should be ignored os.environ['DATABASE__SQLITE'] = 'new_value' - + result = _apply_env_overrides_to_config(cfg) - + # Should remain a dict, not overridden assert isinstance(result['database']['sqlite'], dict) assert result['database']['sqlite']['path'] == 'data/langbot.db' - + del os.environ['DATABASE__SQLITE'] - + def test_ignore_list_type(self): """Test that list/array types are ignored""" - cfg = { - 'admins': ['admin1', 'admin2'], - 'command': { - 'enable': True, - 'prefix': ['!', '!'] - } - } - + cfg = {'admins': ['admin1', 'admin2'], 'command': {'enable': True, 'prefix': ['!', '!']}} + # Try to override list values - should be ignored os.environ['ADMINS'] = 'admin3' os.environ['COMMAND__PREFIX'] = '?' - + result = _apply_env_overrides_to_config(cfg) - + # Should remain lists, not overridden assert isinstance(result['admins'], list) assert result['admins'] == ['admin1', 'admin2'] assert isinstance(result['command']['prefix'], list) assert result['command']['prefix'] == ['!', '!'] - + del os.environ['ADMINS'] del os.environ['COMMAND__PREFIX'] - + def test_lowercase_env_var_ignored(self): """Test that lowercase environment variables are ignored""" - cfg = { - 'api': { - 'port': 5300 - } - } - + cfg = {'api': {'port': 5300}} + os.environ['api__port'] = '8080' - + result = _apply_env_overrides_to_config(cfg) - + # Should not be overridden assert result['api']['port'] == 5300 - + del os.environ['api__port'] - + def test_no_double_underscore_ignored(self): """Test that env vars without __ are ignored""" - cfg = { - 'api': { - 'port': 5300 - } - } - + cfg = {'api': {'port': 5300}} + os.environ['APIPORT'] = '8080' - + result = _apply_env_overrides_to_config(cfg) - + # Should not be overridden assert result['api']['port'] == 5300 - + del os.environ['APIPORT'] - + def test_nonexistent_key_ignored(self): """Test that env vars for non-existent keys are ignored""" - cfg = { - 'api': { - 'port': 5300 - } - } - + cfg = {'api': {'port': 5300}} + os.environ['API__NONEXISTENT'] = 'value' - + result = _apply_env_overrides_to_config(cfg) - + # Should not create new key assert 'nonexistent' not in result['api'] - + del os.environ['API__NONEXISTENT'] - + def test_integer_conversion(self): """Test integer value conversion""" - cfg = { - 'concurrency': { - 'pipeline': 20 - } - } - + cfg = {'concurrency': {'pipeline': 20}} + os.environ['CONCURRENCY__PIPELINE'] = '100' - + result = _apply_env_overrides_to_config(cfg) - + assert result['concurrency']['pipeline'] == 100 assert isinstance(result['concurrency']['pipeline'], int) - + del os.environ['CONCURRENCY__PIPELINE'] - + def test_multiple_overrides(self): """Test multiple environment variable overrides at once""" - cfg = { - 'api': { - 'port': 5300 - }, - 'concurrency': { - 'pipeline': 20, - 'session': 1 - }, - 'plugin': { - 'enable': False - } - } - + cfg = {'api': {'port': 5300}, 'concurrency': {'pipeline': 20, 'session': 1}, 'plugin': {'enable': False}} + os.environ['API__PORT'] = '8080' os.environ['CONCURRENCY__PIPELINE'] = '50' os.environ['PLUGIN__ENABLE'] = 'true' - + result = _apply_env_overrides_to_config(cfg) - + assert result['api']['port'] == 8080 assert result['concurrency']['pipeline'] == 50 assert result['plugin']['enable'] is True - + del os.environ['API__PORT'] del os.environ['CONCURRENCY__PIPELINE'] del os.environ['PLUGIN__ENABLE'] diff --git a/tests/unit_tests/plugin/test_plugin_component_filtering.py b/tests/unit_tests/plugin/test_plugin_component_filtering.py index b83667c5..c2c4fd76 100644 --- a/tests/unit_tests/plugin/test_plugin_component_filtering.py +++ b/tests/unit_tests/plugin/test_plugin_component_filtering.py @@ -1,6 +1,5 @@ """Test plugin list filtering by component kinds.""" -from datetime import datetime from unittest.mock import AsyncMock, MagicMock import pytest @@ -31,16 +30,7 @@ async def test_plugin_list_filter_by_component_kinds(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'Tool', - 'metadata': {'name': 'tool1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}], }, { 'debug': False, @@ -53,15 +43,8 @@ async def test_plugin_list_filter_by_component_kinds(): } }, 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'KnowledgeRetriever', - 'metadata': {'name': 'retriever1'} - } - } - } - ] + {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}} + ], }, { 'debug': False, @@ -73,16 +56,7 @@ async def test_plugin_list_filter_by_component_kinds(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'Command', - 'metadata': {'name': 'cmd1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'Command', 'metadata': {'name': 'cmd1'}}}}], }, { 'debug': False, @@ -94,16 +68,7 @@ async def test_plugin_list_filter_by_component_kinds(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'EventListener', - 'metadata': {'name': 'listener1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'EventListener', 'metadata': {'name': 'listener1'}}}}], }, { 'debug': False, @@ -116,23 +81,9 @@ async def test_plugin_list_filter_by_component_kinds(): } }, 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'KnowledgeRetriever', - 'metadata': {'name': 'retriever2'} - } - } - }, - { - 'manifest': { - 'manifest': { - 'kind': 'Tool', - 'metadata': {'name': 'tool2'} - } - } - } - ] + {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever2'}}}}, + {'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}}, + ], }, ] @@ -187,16 +138,7 @@ async def test_plugin_list_filter_no_filter(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'Tool', - 'metadata': {'name': 'tool1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}], }, { 'debug': False, @@ -209,15 +151,8 @@ async def test_plugin_list_filter_no_filter(): } }, 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'KnowledgeRetriever', - 'metadata': {'name': 'retriever1'} - } - } - } - ] + {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}} + ], }, ] @@ -267,15 +202,8 @@ async def test_plugin_list_filter_empty_result(): } }, 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'KnowledgeRetriever', - 'metadata': {'name': 'retriever1'} - } - } - } - ] + {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}} + ], }, ] @@ -321,16 +249,7 @@ async def test_plugin_list_filter_plugin_without_components(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'Tool', - 'metadata': {'name': 'tool1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}], }, { 'debug': False, @@ -342,7 +261,7 @@ async def test_plugin_list_filter_plugin_without_components(): } } }, - 'components': [] + 'components': [], }, ]