From 195b694ecca1b50b4d4226f196e263f38d0b7640 Mon Sep 17 00:00:00 2001 From: Tiankai Ma Date: Mon, 19 Jan 2026 23:42:17 +0800 Subject: [PATCH 1/6] feat(telegram): threaded mode support (#1920) * feat(telegram): reply in threaded mode * feat(telegram): thread-level isolation --- src/langbot/pkg/platform/botmgr.py | 18 ++++++++++++-- src/langbot/pkg/platform/sources/telegram.py | 26 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/langbot/pkg/platform/botmgr.py b/src/langbot/pkg/platform/botmgr.py index b9d7a5fe..43b8a7ab 100644 --- a/src/langbot/pkg/platform/botmgr.py +++ b/src/langbot/pkg/platform/botmgr.py @@ -75,10 +75,17 @@ class RuntimeBot: # Only add to query pool if no webhook requested to skip pipeline if not skip_pipeline: + launcher_id = event.sender.id + + if hasattr(adapter, 'get_launcher_id'): + custom_launcher_id = adapter.get_launcher_id(event) + if custom_launcher_id: + launcher_id = custom_launcher_id + await self.ap.query_pool.add_query( bot_uuid=self.bot_entity.uuid, launcher_type=provider_session.LauncherTypes.PERSON, - launcher_id=event.sender.id, + launcher_id=launcher_id, sender_id=event.sender.id, message_event=event, message_chain=event.message_chain, @@ -111,10 +118,17 @@ class RuntimeBot: # Only add to query pool if no webhook requested to skip pipeline if not skip_pipeline: + launcher_id = event.group.id + + if hasattr(adapter, 'get_launcher_id'): + custom_launcher_id = adapter.get_launcher_id(event) + if custom_launcher_id: + launcher_id = custom_launcher_id + await self.ap.query_pool.add_query( bot_uuid=self.bot_entity.uuid, launcher_type=provider_session.LauncherTypes.GROUP, - launcher_id=event.group.id, + launcher_id=launcher_id, sender_id=event.sender.id, message_event=event, message_chain=event.message_chain, diff --git a/src/langbot/pkg/platform/sources/telegram.py b/src/langbot/pkg/platform/sources/telegram.py index cfdbe75c..79a959fa 100644 --- a/src/langbot/pkg/platform/sources/telegram.py +++ b/src/langbot/pkg/platform/sources/telegram.py @@ -197,6 +197,10 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): } if self.config['markdown_card'] is True: args['parse_mode'] = 'MarkdownV2' + + if message_source.source_platform_object.message.message_thread_id: + args['message_thread_id'] = message_source.source_platform_object.message.message_thread_id + if quote_origin: args['reply_to_message_id'] = message_source.source_platform_object.message.id @@ -231,8 +235,12 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): 'chat_id': message_source.source_platform_object.effective_chat.id, 'text': content, } + if message_source.source_platform_object.message.message_thread_id: + args['message_thread_id'] = message_source.source_platform_object.message.message_thread_id + if quote_origin: args['reply_to_message_id'] = message_source.source_platform_object.message.id + if self.config['markdown_card'] is True: args['parse_mode'] = 'MarkdownV2' @@ -260,6 +268,24 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter): # self.seq = 1 # 消息回复结束之后重置seq self.msg_stream_id.pop(message_id) # 消息回复结束之后删除流式消息id + def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None: + if not isinstance(event.source_platform_object, Update): + return None + + message = event.source_platform_object.message + if not message: + return None + + # specifically handle telegram forum topic and private thread(not supported by official client yet but supported by bot api) + if message.message_thread_id: + # check if it is a group + if isinstance(event, platform_events.GroupMessage): + return f'{event.group.id}#{message.message_thread_id}' + elif isinstance(event, platform_events.FriendMessage): + return f'{event.sender.id}#{message.message_thread_id}' + + return None + async def is_stream_output_supported(self) -> bool: is_stream = False if self.config.get('enable-stream-reply', None): From 604cc53973a48693d7bfb6256c325cfa19b9a052 Mon Sep 17 00:00:00 2001 From: Tiankai Ma Date: Mon, 19 Jan 2026 23:42:47 +0800 Subject: [PATCH 2/6] fix(localagent): allow empty func arg (#1921) --- src/langbot/pkg/provider/runners/localagent.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index b335ed11..5e5be37c 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -212,7 +212,10 @@ class LocalAgentRunner(runner.RequestRunner): try: func = tool_call.function - parameters = json.loads(func.arguments) + if func.arguments: + parameters = json.loads(func.arguments) + else: + parameters = {} func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query) From fe8a738cd7b5524c3d452bb99e284ec9b45bf666 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Tue, 20 Jan 2026 01:53:42 +0800 Subject: [PATCH 3/6] fix(i18n): update apiKeyCreatedMessage for clarity across multiple languages --- web/src/i18n/locales/en-US.ts | 3 ++- web/src/i18n/locales/ja-JP.ts | 3 ++- web/src/i18n/locales/zh-Hans.ts | 2 +- web/src/i18n/locales/zh-Hant.ts | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index cf79c55a..c097eb4e 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -121,7 +121,8 @@ const enUS = { webhookHint: 'Webhooks allow LangBot to push person and group message events to external systems', actions: 'Actions', - apiKeyCreatedMessage: 'Please copy this API key.', + apiKeyCreatedMessage: + 'Please copy this API key, if the button is invalid, please copy manually.', none: 'None', }, notFound: { diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index b812d937..903e8ea5 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -123,7 +123,8 @@ const jaJP = { webhookHint: 'Webhook を使用すると、LangBot は個人メッセージとグループメッセージイベントを外部システムにプッシュできます', actions: 'アクション', - apiKeyCreatedMessage: 'この API キーをコピーしてください。', + apiKeyCreatedMessage: + 'この API キーをコピーしてください。もしボタンが無効な場合は手動でコピーしてください。', none: 'なし', }, notFound: { diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 473a8401..6ce7a13a 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -114,7 +114,7 @@ const zhHans = { noWebhooks: '暂无 Webhook', webhookHint: 'Webhook 允许 LangBot 将个人消息和群消息事件推送到外部系统', actions: '操作', - apiKeyCreatedMessage: '请复制此 API 密钥。', + apiKeyCreatedMessage: '请复制此 API 密钥,若按钮无效,请手动复制。', none: '无', }, notFound: { diff --git a/web/src/i18n/locales/zh-Hant.ts b/web/src/i18n/locales/zh-Hant.ts index f4cf40b9..2dac2873 100644 --- a/web/src/i18n/locales/zh-Hant.ts +++ b/web/src/i18n/locales/zh-Hant.ts @@ -114,7 +114,7 @@ const zhHant = { noWebhooks: '暫無 Webhook', webhookHint: 'Webhook 允許 LangBot 將個人訊息和群組訊息事件推送到外部系統', actions: '操作', - apiKeyCreatedMessage: '請複製此 API 金鑰。', + apiKeyCreatedMessage: '請複製此 API 金鑰,若按鈕無效,請手動複製。', none: '無', }, notFound: { From c90f2d6a1244dd44658971d82c646845f5ca2f8b Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Tue, 20 Jan 2026 01:59:19 +0800 Subject: [PATCH 4/6] chore: update mcp dependency version to 1.25.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c3daa41e..a5aee60c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "pynacl>=1.5.0", # Required for Discord voice support "gewechat-client>=0.1.5", "lark-oapi>=1.4.15", - "mcp>=1.20.0", + "mcp>=1.25.0", "nakuru-project-idk>=0.0.2.1", "ollama>=0.4.8", "openai>1.0.0", From e60cb6ad0e5a97f942d4c9c5bb196ae33097f7c2 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Fri, 23 Jan 2026 13:30:44 +0800 Subject: [PATCH 5/6] fix: ruff check errors --- src/langbot/pkg/platform/botmgr.py | 4 +- src/langbot/pkg/platform/webhook_pusher.py | 4 +- src/langbot/pkg/vector/vdbs/pgvector_db.py | 96 ++++++++-------------- src/langbot/pkg/vector/vdbs/seekdb.py | 32 ++++---- 4 files changed, 52 insertions(+), 84 deletions(-) diff --git a/src/langbot/pkg/platform/botmgr.py b/src/langbot/pkg/platform/botmgr.py index 43b8a7ab..7540808d 100644 --- a/src/langbot/pkg/platform/botmgr.py +++ b/src/langbot/pkg/platform/botmgr.py @@ -93,7 +93,7 @@ class RuntimeBot: pipeline_uuid=self.bot_entity.use_pipeline_uuid, ) else: - await self.logger.info(f'Pipeline skipped for person message due to webhook response') + await self.logger.info('Pipeline skipped for person message due to webhook response') async def on_group_message( event: platform_events.GroupMessage, @@ -136,7 +136,7 @@ class RuntimeBot: pipeline_uuid=self.bot_entity.use_pipeline_uuid, ) else: - await self.logger.info(f'Pipeline skipped for group message due to webhook response') + await self.logger.info('Pipeline skipped for group message due to webhook response') self.adapter.register_listener(platform_events.FriendMessage, on_friend_message) self.adapter.register_listener(platform_events.GroupMessage, on_group_message) diff --git a/src/langbot/pkg/platform/webhook_pusher.py b/src/langbot/pkg/platform/webhook_pusher.py index 15d25733..5a8d2564 100644 --- a/src/langbot/pkg/platform/webhook_pusher.py +++ b/src/langbot/pkg/platform/webhook_pusher.py @@ -56,7 +56,7 @@ class WebhookPusher: # Check if any webhook responded with skip_pipeline=true for result in results: if isinstance(result, dict) and result.get('skip_pipeline') is True: - self.logger.info(f'Webhook responded with skip_pipeline=true, skipping pipeline for person message') + self.logger.info('Webhook responded with skip_pipeline=true, skipping pipeline for person message') return True return False @@ -103,7 +103,7 @@ class WebhookPusher: # Check if any webhook responded with skip_pipeline=true for result in results: if isinstance(result, dict) and result.get('skip_pipeline') is True: - self.logger.info(f'Webhook responded with skip_pipeline=true, skipping pipeline for group message') + self.logger.info('Webhook responded with skip_pipeline=true, skipping pipeline for group message') return True return False diff --git a/src/langbot/pkg/vector/vdbs/pgvector_db.py b/src/langbot/pkg/vector/vdbs/pgvector_db.py index 2669902b..7490f228 100644 --- a/src/langbot/pkg/vector/vdbs/pgvector_db.py +++ b/src/langbot/pkg/vector/vdbs/pgvector_db.py @@ -1,19 +1,18 @@ from __future__ import annotations -import asyncio from typing import Any, Dict from sqlalchemy import create_engine, text, Column, String, Text -from sqlalchemy.orm import declarative_base, sessionmaker, Session +from sqlalchemy.orm import declarative_base from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from pgvector.sqlalchemy import Vector from langbot.pkg.vector.vdb import VectorDatabase from langbot.pkg.core import app -import uuid Base = declarative_base() class PgVectorEntry(Base): """SQLAlchemy model for pgvector entries""" + __tablename__ = 'langbot_vectors' id = Column(String, primary_key=True) @@ -31,11 +30,11 @@ class PgVectorDatabase(VectorDatabase): self, ap: app.Application, connection_string: str = None, - host: str = "localhost", + host: str = 'localhost', port: int = 5432, - database: str = "langbot", - user: str = "postgres", - password: str = "postgres" + database: str = 'langbot', + user: str = 'postgres', + password: str = 'postgres', ): """Initialize pgvector database @@ -54,14 +53,10 @@ class PgVectorDatabase(VectorDatabase): if connection_string: self.connection_string = connection_string else: - self.connection_string = ( - f"postgresql+psycopg://{user}:{password}@{host}:{port}/{database}" - ) + self.connection_string = f'postgresql+psycopg://{user}:{password}@{host}:{port}/{database}' - self.async_connection_string = self.connection_string.replace( - "postgresql://", "postgresql+asyncpg://" - ).replace( - "postgresql+psycopg://", "postgresql+asyncpg://" + self.async_connection_string = self.connection_string.replace('postgresql://', 'postgresql+asyncpg://').replace( + 'postgresql+psycopg://', 'postgresql+asyncpg://' ) self.engine = None @@ -75,35 +70,25 @@ class PgVectorDatabase(VectorDatabase): """Initialize database connection and create tables""" try: # Create async engine for async operations - self.async_engine = create_async_engine( - self.async_connection_string, - echo=False, - pool_pre_ping=True - ) - self.AsyncSessionLocal = async_sessionmaker( - self.async_engine, - class_=AsyncSession, - expire_on_commit=False - ) + self.async_engine = create_async_engine(self.async_connection_string, echo=False, pool_pre_ping=True) + self.AsyncSessionLocal = async_sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False) # Create sync engine for table creation - sync_connection_string = self.connection_string.replace( - "postgresql+asyncpg://", "postgresql+psycopg://" - ) + sync_connection_string = self.connection_string.replace('postgresql+asyncpg://', 'postgresql+psycopg://') self.engine = create_engine(sync_connection_string, echo=False) # Create pgvector extension and tables with self.engine.connect() as conn: # Enable pgvector extension - conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) + conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector')) conn.commit() # Create tables Base.metadata.create_all(self.engine) - self.ap.logger.info(f"Connected to PostgreSQL with pgvector") + self.ap.logger.info('Connected to PostgreSQL with pgvector') except Exception as e: - self.ap.logger.error(f"Failed to connect to PostgreSQL: {e}") + self.ap.logger.error(f'Failed to connect to PostgreSQL: {e}') raise async def get_or_create_collection(self, collection: str): @@ -144,24 +129,20 @@ class PgVectorDatabase(VectorDatabase): id=vector_id, collection=collection, embedding=embeddings_list[i], - text=metadata.get("text", ""), - file_id=metadata.get("file_id", ""), - chunk_uuid=metadata.get("uuid", "") + text=metadata.get('text', ''), + file_id=metadata.get('file_id', ''), + chunk_uuid=metadata.get('uuid', ''), ) session.add(entry) await session.commit() - self.ap.logger.info( - f"Added {len(ids)} embeddings to pgvector collection '{collection}'" - ) + self.ap.logger.info(f"Added {len(ids)} embeddings to pgvector collection '{collection}'") except Exception as e: await session.rollback() - self.ap.logger.error(f"Error adding embeddings to pgvector: {e}") + self.ap.logger.error(f'Error adding embeddings to pgvector: {e}') raise - async def search( - self, collection: str, query_embedding: list[float], k: int = 5 - ) -> Dict[str, Any]: + async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]: """Search for similar vectors using cosine distance Args: @@ -177,7 +158,7 @@ class PgVectorDatabase(VectorDatabase): async with self.AsyncSessionLocal() as session: try: # Use cosine distance for similarity search - from sqlalchemy import select, func + from sqlalchemy import select # Query for similar vectors stmt = ( @@ -186,7 +167,7 @@ class PgVectorDatabase(VectorDatabase): PgVectorEntry.text, PgVectorEntry.file_id, PgVectorEntry.chunk_uuid, - PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance') + PgVectorEntry.embedding.cosine_distance(query_embedding).label('distance'), ) .filter(PgVectorEntry.collection == collection) .order_by(PgVectorEntry.embedding.cosine_distance(query_embedding)) @@ -204,25 +185,17 @@ class PgVectorDatabase(VectorDatabase): for row in rows: ids.append(row.id) distances.append(float(row.distance)) - metadatas.append({ - "text": row.text or "", - "file_id": row.file_id or "", - "uuid": row.chunk_uuid or "" - }) + metadatas.append( + {'text': row.text or '', 'file_id': row.file_id or '', 'uuid': row.chunk_uuid or ''} + ) - result_dict = { - "ids": [ids], - "distances": [distances], - "metadatas": [metadatas] - } + result_dict = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]} - self.ap.logger.info( - f"pgvector search in '{collection}' returned {len(ids)} results" - ) + self.ap.logger.info(f"pgvector search in '{collection}' returned {len(ids)} results") return result_dict except Exception as e: - self.ap.logger.error(f"Error searching pgvector: {e}") + self.ap.logger.error(f'Error searching pgvector: {e}') raise async def delete_by_file_id(self, collection: str, file_id: str) -> None: @@ -239,8 +212,7 @@ class PgVectorDatabase(VectorDatabase): from sqlalchemy import delete stmt = delete(PgVectorEntry).where( - PgVectorEntry.collection == collection, - PgVectorEntry.file_id == file_id + PgVectorEntry.collection == collection, PgVectorEntry.file_id == file_id ) await session.execute(stmt) await session.commit() @@ -250,7 +222,7 @@ class PgVectorDatabase(VectorDatabase): ) except Exception as e: await session.rollback() - self.ap.logger.error(f"Error deleting from pgvector: {e}") + self.ap.logger.error(f'Error deleting from pgvector: {e}') raise async def delete_collection(self, collection: str): @@ -266,16 +238,14 @@ class PgVectorDatabase(VectorDatabase): try: from sqlalchemy import delete - stmt = delete(PgVectorEntry).where( - PgVectorEntry.collection == collection - ) + stmt = delete(PgVectorEntry).where(PgVectorEntry.collection == collection) await session.execute(stmt) await session.commit() self.ap.logger.info(f"Deleted pgvector collection '{collection}'") except Exception as e: await session.rollback() - self.ap.logger.error(f"Error deleting pgvector collection: {e}") + self.ap.logger.error(f'Error deleting pgvector collection: {e}') raise async def close(self): diff --git a/src/langbot/pkg/vector/vdbs/seekdb.py b/src/langbot/pkg/vector/vdbs/seekdb.py index acb5e67d..b007f2fb 100644 --- a/src/langbot/pkg/vector/vdbs/seekdb.py +++ b/src/langbot/pkg/vector/vdbs/seekdb.py @@ -3,10 +3,8 @@ from __future__ import annotations import asyncio from typing import Any, Dict, List -import sqlalchemy from langbot.pkg.core import app -from langbot.pkg.entity.persistence import model as persistence_model from langbot.pkg.vector.vdb import VectorDatabase try: @@ -87,14 +85,16 @@ class SeekDBVectorDatabase(VectorDatabase): self._collections: Dict[str, Any] = {} self._collection_configs: Dict[str, HNSWConfiguration] = {} - self._escape_table = str.maketrans({ - '\x00': '', - '\\': '\\\\', - '"': '\\"', - '\n': '\\n', - '\r': '\\r', - '\t': '\\t', - }) + self._escape_table = str.maketrans( + { + '\x00': '', + '\\': '\\\\', + '"': '\\"', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', + } + ) async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None) -> Any: """Internal method to get or create a collection with proper configuration.""" @@ -133,8 +133,10 @@ class SeekDBVectorDatabase(VectorDatabase): def _clean_metadata(self, meta: Dict[str, Any]) -> Dict[str, Any]: """SeekDB metadata doesn't support \\ and ", insert will error 3104""" return { - k: v.translate(self._escape_table) if isinstance(v, str) - else v if v is None or isinstance(v, (int, float, bool)) + k: v.translate(self._escape_table) + if isinstance(v, str) + else v + if v is None or isinstance(v, (int, float, bool)) else str(v) for k, v in meta.items() if v is not None @@ -145,11 +147,7 @@ class SeekDBVectorDatabase(VectorDatabase): return await self._get_or_create_collection_internal(collection) async def add_embeddings( - self, - collection: str, - ids: List[str], - embeddings_list: List[List[float]], - metadatas: List[Dict[str, Any]] + self, collection: str, ids: List[str], embeddings_list: List[List[float]], metadatas: List[Dict[str, Any]] ) -> None: """Add vector embeddings to the specified collection. From fc6e414be4cee596055a1d79e5b8a6ff2882c035 Mon Sep 17 00:00:00 2001 From: "Junyan Qin (Chin)" Date: Fri, 23 Jan 2026 13:43:12 +0800 Subject: [PATCH 6/6] feat: add GitHub Actions workflow for linting with Ruff (#1929) * feat: add GitHub Actions workflow for linting with Ruff * refactor: rename lint job and add formatting step to Ruff workflow * chore: run ruff format * chore: rename Ruff lint job to 'Lint' and add frontend linting workflow --- .github/workflows/lint.yml | 60 +++++ src/langbot/libs/qq_official_api/api.py | 12 +- .../libs/wecom_ai_bot_api/wecombotevent.py | 9 +- .../api/http/controller/groups/webhooks.py | 3 - src/langbot/pkg/api/http/service/bot.py | 11 +- src/langbot/pkg/core/stages/build_app.py | 1 - src/langbot/pkg/platform/sources/lark.py | 12 +- src/langbot/pkg/platform/sources/wecom.py | 38 +-- .../modelmgr/requesters/seekdbembed.py | 3 +- .../pkg/provider/runners/localagent.py | 8 +- src/langbot/pkg/provider/runners/n8nsvapi.py | 97 ++++--- src/langbot/pkg/provider/tools/loaders/mcp.py | 17 +- .../pkg/rag/knowledge/services/embedder.py | 4 +- src/langbot/pkg/vector/mgr.py | 7 +- src/langbot/pkg/vector/vdbs/milvus.py | 147 +++++------ tests/unit_tests/config/test_env_override.py | 237 +++++++----------- .../plugin/test_plugin_component_filtering.py | 111 ++------ 17 files changed, 327 insertions(+), 450 deletions(-) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..e1d89c1e --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,60 @@ +name: Lint + +on: + push: + branches: + - main + - master + - dev + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +jobs: + ruff: + name: Ruff Lint & Format + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install uv + uses: astral-sh/setup-uv@v4 + + - name: Install dependencies + run: uv sync --dev + + - name: Run ruff check + run: uv run ruff check src + + - name: Run ruff format + run: uv run ruff format src --check + + frontend: + name: Frontend Lint + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '25' + + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + version: 9 + + - name: Install dependencies + working-directory: web + run: pnpm install + + - name: Run lint + working-directory: web + run: pnpm lint diff --git a/src/langbot/libs/qq_official_api/api.py b/src/langbot/libs/qq_official_api/api.py index e4d4e468..51a56d53 100644 --- a/src/langbot/libs/qq_official_api/api.py +++ b/src/langbot/libs/qq_official_api/api.py @@ -85,7 +85,6 @@ class QQOfficialClient: req: Quart Request 对象 """ try: - body = await req.get_data() print(f'[QQ Official] Received request, body length: {len(body)}') @@ -96,7 +95,6 @@ class QQOfficialClient: payload = json.loads(body) - if payload.get('op') == 13: validation_data = payload.get('d') if not validation_data: @@ -276,21 +274,21 @@ class QQOfficialClient: seed = bot_secret while len(seed) < target_size: seed *= 2 - return seed[:target_size].encode("utf-8") + return seed[:target_size].encode('utf-8') async def verify(self, validation_payload: dict): seed = await self.repeat_seed(self.secret) private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed) - event_ts = validation_payload.get("event_ts", "") - plain_token = validation_payload.get("plain_token", "") + event_ts = validation_payload.get('event_ts', '') + plain_token = validation_payload.get('plain_token', '') msg = event_ts + plain_token # sign signature = private_key.sign(msg.encode()).hex() response = { - "plain_token": plain_token, - "signature": signature, + 'plain_token': plain_token, + 'signature': signature, } return response diff --git a/src/langbot/libs/wecom_ai_bot_api/wecombotevent.py b/src/langbot/libs/wecom_ai_bot_api/wecombotevent.py index 75c6bbde..bc105cf8 100644 --- a/src/langbot/libs/wecom_ai_bot_api/wecombotevent.py +++ b/src/langbot/libs/wecom_ai_bot_api/wecombotevent.py @@ -36,7 +36,12 @@ class WecomBotEvent(dict): """ 用户名称 """ - return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid + return ( + self.get('username', '') + or self.get('from', {}).get('alias', '') + or self.get('from', {}).get('name', '') + or self.userid + ) @property def chatname(self) -> str: @@ -121,7 +126,7 @@ class WecomBotEvent(dict): 消息id """ return self.get('msgid', '') - + @property def ai_bot_id(self) -> str: """ diff --git a/src/langbot/pkg/api/http/controller/groups/webhooks.py b/src/langbot/pkg/api/http/controller/groups/webhooks.py index 0964076f..ec46c744 100644 --- a/src/langbot/pkg/api/http/controller/groups/webhooks.py +++ b/src/langbot/pkg/api/http/controller/groups/webhooks.py @@ -30,7 +30,6 @@ class WebhookRouterGroup(group.RouterGroup): 适配器返回的响应 """ try: - runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid) if not runtime_bot: @@ -39,11 +38,9 @@ class WebhookRouterGroup(group.RouterGroup): if not runtime_bot.enable: return quart.jsonify({'error': 'Bot is disabled'}), 403 - if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'): return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501 - response = await runtime_bot.adapter.handle_unified_webhook( bot_uuid=bot_uuid, path=path, diff --git a/src/langbot/pkg/api/http/service/bot.py b/src/langbot/pkg/api/http/service/bot.py index ac7ec13a..0632935b 100644 --- a/src/langbot/pkg/api/http/service/bot.py +++ b/src/langbot/pkg/api/http/service/bot.py @@ -59,7 +59,16 @@ class BotService: adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id # Webhook URL for unified webhook adapters (independent of bot running state) - if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack', 'wecomcs', 'LINE', 'lark']: + if persistence_bot['adapter'] in [ + 'wecom', + 'wecombot', + 'officialaccount', + 'qqofficial', + 'slack', + 'wecomcs', + 'LINE', + 'lark', + ]: webhook_prefix = self.ap.instance_config.data['api'].get('webhook_prefix', 'http://127.0.0.1:5300') webhook_url = f'/bots/{bot_uuid}' adapter_runtime_values['webhook_url'] = webhook_url diff --git a/src/langbot/pkg/core/stages/build_app.py b/src/langbot/pkg/core/stages/build_app.py index e7226041..791b5a9e 100644 --- a/src/langbot/pkg/core/stages/build_app.py +++ b/src/langbot/pkg/core/stages/build_app.py @@ -34,7 +34,6 @@ from .. import taskmgr from ...telemetry import telemetry as telemetry_module - @stage.stage_class('BuildAppStage') class BuildAppStage(stage.BootingStage): """Build LangBot application""" diff --git a/src/langbot/pkg/platform/sources/lark.py b/src/langbot/pkg/platform/sources/lark.py index 3c13c019..f123889c 100644 --- a/src/langbot/pkg/platform/sources/lark.py +++ b/src/langbot/pkg/platform/sources/lark.py @@ -244,7 +244,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time)) - if message.message_type == 'text': element_list = [] @@ -310,7 +309,11 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): ] elif message.message_type == 'audio': message_content['content'] = [ - {'tag': 'audio', 'file_key': message_content['file_key'], "duration": message_content.get('duration',0)} + { + 'tag': 'audio', + 'file_key': message_content['file_key'], + 'duration': message_content.get('duration', 0), + } ] for ele in message_content['content']: @@ -367,12 +370,9 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): audio_bytes = response.file.read() audio_base64 = base64.b64encode(audio_bytes).decode() - # Get content type from response headers content_type = response.raw.headers.get('content-type', 'audio/mpeg') - - mime_main = content_type.split(';')[0].strip() ext = mimetypes.guess_extension(mime_main) or '.bin' temp_dir = tempfile.gettempdir() @@ -418,7 +418,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): file_bytes = response.file.read() file_base64 = base64.b64encode(file_bytes).decode() - file_format = response.raw.headers['content-type'] file_size = len(file_bytes) @@ -453,7 +452,6 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter): ) ) - return platform_message.MessageChain(lb_msg_list) diff --git a/src/langbot/pkg/platform/sources/wecom.py b/src/langbot/pkg/platform/sources/wecom.py index dc14a9b7..7bed676f 100644 --- a/src/langbot/pkg/platform/sources/wecom.py +++ b/src/langbot/pkg/platform/sources/wecom.py @@ -18,52 +18,52 @@ import langbot_plugin.api.entities.builtin.platform.entities as platform_entitie def split_string_by_bytes(text, limit=2048, encoding='utf-8'): """ Splits a string into a list of strings, where each part is at most 'limit' bytes. - + Args: text (str): The original string to split. limit (int): The maximum byte size for each split part. encoding (str): The encoding to use (default is 'utf-8'). - + Returns: list: A list of split strings. """ # 1. Encode the entire string into bytes bytes_data = text.encode(encoding) total_len = len(bytes_data) - + parts = [] start = 0 - + while start < total_len: # 2. Determine the end index for the current chunk # It shouldn't exceed the total length end = min(start + limit, total_len) - + # 3. Slice the byte array chunk = bytes_data[start:end] - + # 4. Attempt to decode the chunk # Use errors='ignore' to drop any partial bytes at the end of the chunk # (e.g., if a 3-byte character was cut after the 2nd byte) part_str = chunk.decode(encoding, errors='ignore') - + # 5. Calculate the actual byte length of the successfully decoded string # This tells us exactly where the valid character boundary ended part_bytes = part_str.encode(encoding) part_len = len(part_bytes) - + # Safety check: Prevent infinite loop if limit is too small (e.g., limit=1 for a Chinese char) if part_len == 0 and end < total_len: # Force advance by 1 byte to consume the un-decodable byte or raise error # Here we just treat it as a part to avoid stuck loops, though it might be invalid - start += 1 + start += 1 continue parts.append(part_str) - + # 6. Move the start pointer by the actual length consumed start += part_len - + return parts @@ -75,13 +75,15 @@ class WecomMessageConverter(abstract_platform_adapter.AbstractMessageConverter): for msg in message_chain: if type(msg) is platform_message.Plain: chunks = split_string_by_bytes(msg.text) - content_list.extend([ - { - 'type': 'text', - 'content': chunk, - } - for chunk in chunks - ]) + content_list.extend( + [ + { + 'type': 'text', + 'content': chunk, + } + for chunk in chunks + ] + ) elif type(msg) is platform_message.Image: content_list.append( { diff --git a/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.py b/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.py index a181d51c..7fd98d69 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/seekdbembed.py @@ -51,9 +51,10 @@ class SeekDBEmbedding(requester.ProviderAPIRequester): await self.initialize() if self._embedding_function is None: - raise RuntimeError("SeekDB embedding function initialization failed") + raise RuntimeError('SeekDB embedding function initialization failed') return self._embedding_function(input_text) except Exception as e: from .. import errors + raise errors.RequesterError(f'SeekDB embedding failed: {str(e)}') diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index 5e5be37c..64dbfc63 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -218,10 +218,14 @@ class LocalAgentRunner(runner.RequestRunner): parameters = {} func_ret = await self.ap.tool_mgr.execute_func_call(func.name, parameters, query=query) - + # Handle return value content tool_content = None - if isinstance(func_ret, list) and len(func_ret) > 0 and isinstance(func_ret[0], provider_message.ContentElement): + if ( + isinstance(func_ret, list) + and len(func_ret) > 0 + and isinstance(func_ret[0], provider_message.ContentElement) + ): tool_content = func_ret else: tool_content = json.dumps(func_ret, ensure_ascii=False) diff --git a/src/langbot/pkg/provider/runners/n8nsvapi.py b/src/langbot/pkg/provider/runners/n8nsvapi.py index 89cb6679..d7ec3ccb 100644 --- a/src/langbot/pkg/provider/runners/n8nsvapi.py +++ b/src/langbot/pkg/provider/runners/n8nsvapi.py @@ -68,15 +68,16 @@ class N8nServiceAPIRunner(runner.RequestRunner): return plain_text - async def _process_stream_response(self, response: aiohttp.ClientResponse) -> typing.AsyncGenerator[ - provider_message.Message, None]: + async def _process_stream_response( + self, response: aiohttp.ClientResponse + ) -> typing.AsyncGenerator[provider_message.Message, None]: """处理流式响应——支持部分 JSON 和多个 JSON 对象在同一 chunk 的情况""" - full_content = "" + full_content = '' chunk_idx = 0 is_final = False message_idx = 0 - buffer = "" + buffer = '' decoder = json.JSONDecoder() async for raw_chunk in response.content.iter_chunked(1024): @@ -129,7 +130,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): preview = chunk_str[:200] except Exception: preview = '' - self.ap.logger.warning(f"Failed to process chunk: {e}; chunk preview: {preview}") + self.ap.logger.warning(f'Failed to process chunk: {e}; chunk preview: {preview}') # 流结束后,尝试解析残余 buffer if buffer: @@ -151,7 +152,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): ) except Exception as e: preview = buffer[:200] - self.ap.logger.warning(f"Failed to parse remaining buffer: {e}; buffer preview: {preview}") + self.ap.logger.warning(f'Failed to parse remaining buffer: {e}; buffer preview: {preview}') async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """调用n8n webhook""" @@ -165,7 +166,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): # 准备请求数据 payload = { # 基本消息内容 - 'chatInput' :plain_text, # 考虑到之前用户直接用的message model这里添加新键 + 'chatInput': plain_text, # 考虑到之前用户直接用的message model这里添加新键 'message': plain_text, 'user_message_text': plain_text, 'conversation_id': query.session.using_conversation.uuid, @@ -217,57 +218,49 @@ class N8nServiceAPIRunner(runner.RequestRunner): # 调用webhook async with aiohttp.ClientSession() as session: - if is_stream: - # 流式请求 - async with session.post( - self.webhook_url, - json=payload, - headers=headers, - auth=auth, - timeout=self.timeout - ) as response: + if is_stream: + # 流式请求 + async with session.post( + self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout + ) as response: + if response.status != 200: + error_text = await response.text() + self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') + raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') + + # 处理流式响应 + async for chunk in self._process_stream_response(response): + yield chunk + else: + async with session.post( + self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout + ) as response: + try: + async for chunk in self._process_stream_response(response): + output_content = chunk.content if chunk.is_final else '' + except: + # 非流式请求(保持原有逻辑) if response.status != 200: error_text = await response.text() self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') - # 处理流式响应 - async for chunk in self._process_stream_response(response): - yield chunk - else: - async with session.post( - self.webhook_url, - json=payload, - headers=headers, - auth=auth, - timeout=self.timeout - ) as response: - try: - async for chunk in self._process_stream_response(response): - output_content = chunk.content if chunk.is_final else '' - except: - # 非流式请求(保持原有逻辑) - if response.status != 200: - error_text = await response.text() - self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') - raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') + # 解析响应 + response_data = await response.json() + self.ap.logger.debug(f'n8n webhook response: {response_data}') - # 解析响应 - response_data = await response.json() - self.ap.logger.debug(f'n8n webhook response: {response_data}') + # 从响应中提取输出 + if self.output_key in response_data: + output_content = response_data[self.output_key] + else: + # 如果没有指定的输出键,则使用整个响应 + output_content = json.dumps(response_data, ensure_ascii=False) - # 从响应中提取输出 - if self.output_key in response_data: - output_content = response_data[self.output_key] - else: - # 如果没有指定的输出键,则使用整个响应 - output_content = json.dumps(response_data, ensure_ascii=False) - - # 返回消息 - yield provider_message.Message( - role='assistant', - content=output_content, - ) + # 返回消息 + yield provider_message.Message( + role='assistant', + content=output_content, + ) except Exception as e: self.ap.logger.error(f'n8n webhook call exception: {str(e)}') raise N8nAPIError(f'n8n webhook call exception: {str(e)}') @@ -275,4 +268,4 @@ class N8nServiceAPIRunner(runner.RequestRunner): async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]: """运行请求""" async for msg in self._call_webhook(query): - yield msg \ No newline at end of file + yield msg diff --git a/src/langbot/pkg/provider/tools/loaders/mcp.py b/src/langbot/pkg/provider/tools/loaders/mcp.py index 4b0583c6..46d63b84 100644 --- a/src/langbot/pkg/provider/tools/loaders/mcp.py +++ b/src/langbot/pkg/provider/tools/loaders/mcp.py @@ -194,7 +194,7 @@ class RuntimeMCPSession: async def func(*, _tool=tool, **kwargs): if not self.session: - raise Exception("MCP session is not connected") + raise Exception('MCP session is not connected') result = await self.session.call_tool(_tool.name, kwargs) if result.isError: @@ -202,8 +202,8 @@ class RuntimeMCPSession: for content in result.content: if content.type == 'text': error_texts.append(content.text) - raise Exception("\n".join(error_texts) if error_texts else "Unknown error from MCP tool") - + raise Exception('\n'.join(error_texts) if error_texts else 'Unknown error from MCP tool') + result_contents: list[provider_message.ContentElement] = [] for content in result.content: if content.type == 'text': @@ -213,7 +213,7 @@ class RuntimeMCPSession: elif content.type == 'resource': # TODO: Handle resource content pass - + return result_contents func.__name__ = tool.name @@ -221,8 +221,8 @@ class RuntimeMCPSession: self.functions.append( resource_tool.LLMTool( name=tool.name, - human_desc=tool.description or "", - description=tool.description or "", + human_desc=tool.description or '', + description=tool.description or '', parameters=tool.inputSchema, func=func, ) @@ -338,13 +338,10 @@ class MCPLoader(loader.ToolLoader): """ uuid_ = server_config.get('uuid') if not uuid_: - self.ap.logger.warning( - 'Server UUID is None for MCP server, maybe testing in the config page.' - ) + self.ap.logger.warning('Server UUID is None for MCP server, maybe testing in the config page.') uuid_ = str(uuid_module.uuid4()) server_config['uuid'] = uuid_ - name = server_config['name'] uuid = server_config['uuid'] mode = server_config['mode'] diff --git a/src/langbot/pkg/rag/knowledge/services/embedder.py b/src/langbot/pkg/rag/knowledge/services/embedder.py index f93382ff..485bc21e 100644 --- a/src/langbot/pkg/rag/knowledge/services/embedder.py +++ b/src/langbot/pkg/rag/knowledge/services/embedder.py @@ -35,9 +35,9 @@ class Embedder(BaseService): # get embeddings (batch size limit: 64 for OpenAI) MAX_BATCH_SIZE = 64 embeddings_list: list[list[float]] = [] - + for i in range(0, len(chunks), MAX_BATCH_SIZE): - batch = chunks[i:i + MAX_BATCH_SIZE] + batch = chunks[i : i + MAX_BATCH_SIZE] batch_embeddings = await embedding_model.provider.requester.invoke_embedding( model=embedding_model, input_text=batch, diff --git a/src/langbot/pkg/vector/mgr.py b/src/langbot/pkg/vector/mgr.py index f95f5f75..f0cb742c 100644 --- a/src/langbot/pkg/vector/mgr.py +++ b/src/langbot/pkg/vector/mgr.py @@ -55,12 +55,7 @@ class VectorDBManager: user = pgvector_config.get('user', 'postgres') password = pgvector_config.get('password', 'postgres') self.vector_db = PgVectorDatabase( - self.ap, - host=host, - port=port, - database=database, - user=user, - password=password + self.ap, host=host, port=port, database=database, user=user, password=password ) self.ap.logger.info('Initialized pgvector database backend.') diff --git a/src/langbot/pkg/vector/vdbs/milvus.py b/src/langbot/pkg/vector/vdbs/milvus.py index f15071c4..2852dea1 100644 --- a/src/langbot/pkg/vector/vdbs/milvus.py +++ b/src/langbot/pkg/vector/vdbs/milvus.py @@ -10,7 +10,7 @@ from langbot.pkg.core import app class MilvusVectorDatabase(VectorDatabase): """Milvus vector database implementation""" - def __init__(self, ap: app.Application, uri: str = "milvus.db", token: str = None, db_name: str = None): + def __init__(self, ap: app.Application, uri: str = 'milvus.db', token: str = None, db_name: str = None): """Initialize Milvus vector database Args: @@ -34,32 +34,32 @@ class MilvusVectorDatabase(VectorDatabase): self.client = MilvusClient(uri=self.uri, token=self.token, db_name=self.db_name) else: self.client = MilvusClient(uri=self.uri, db_name=self.db_name) - self.ap.logger.info(f"Connected to Milvus at {self.uri}") + self.ap.logger.info(f'Connected to Milvus at {self.uri}') except Exception as e: - self.ap.logger.error(f"Failed to connect to Milvus: {e}") + self.ap.logger.error(f'Failed to connect to Milvus: {e}') raise @staticmethod def _normalize_collection_name(collection: str) -> str: """Normalize collection name to comply with Milvus naming requirements. - + Milvus requirements: - First character must be an underscore or letter - Can only contain numbers, letters and underscores - + Args: collection: Original collection name (e.g., UUID with hyphens) - + Returns: Normalized collection name that complies with Milvus requirements """ # Replace hyphens with underscores normalized = collection.replace('-', '_') - + # If first character is not a letter or underscore, prepend 'kb_' if normalized and not (normalized[0].isalpha() or normalized[0] == '_'): normalized = 'kb_' + normalized - + return normalized async def _ensure_vector_index(self, collection: str) -> None: @@ -70,15 +70,11 @@ class MilvusVectorDatabase(VectorDatabase): """ index_params = IndexParams() index_params.add_index( - field_name="vector", - index_type="AUTOINDEX", - metric_type="COSINE", - ) - await asyncio.to_thread( - self.client.create_index, - collection_name=collection, - index_params=index_params + field_name='vector', + index_type='AUTOINDEX', + metric_type='COSINE', ) + await asyncio.to_thread(self.client.create_index, collection_name=collection, index_params=index_params) async def _get_or_create_collection_internal(self, collection: str, vector_size: int = None): """Internal method to get or create a Milvus collection with proper configuration. @@ -89,14 +85,12 @@ class MilvusVectorDatabase(VectorDatabase): """ # Normalize collection name for Milvus compatibility collection = self._normalize_collection_name(collection) - + if collection in self._collections: return collection # Check if collection exists - has_collection = await asyncio.to_thread( - self.client.has_collection, collection_name=collection - ) + has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection) if not has_collection: # Default dimension if not specified (for backward compatibility) @@ -104,24 +98,26 @@ class MilvusVectorDatabase(VectorDatabase): vector_size = 1536 fields = [ - FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=255), - FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=vector_size), - FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), - FieldSchema(name="file_id", dtype=DataType.VARCHAR, max_length=255), - FieldSchema(name="chunk_uuid", dtype=DataType.VARCHAR, max_length=255), + FieldSchema(name='id', dtype=DataType.VARCHAR, is_primary=True, max_length=255), + FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=vector_size), + FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535), + FieldSchema(name='file_id', dtype=DataType.VARCHAR, max_length=255), + FieldSchema(name='chunk_uuid', dtype=DataType.VARCHAR, max_length=255), ] - schema = CollectionSchema(fields=fields, description="LangBot knowledge base vectors") + schema = CollectionSchema(fields=fields, description='LangBot knowledge base vectors') await asyncio.to_thread( self.client.create_collection, collection_name=collection, schema=schema, - metric_type="COSINE", + metric_type='COSINE', ) await self._ensure_vector_index(collection) - self.ap.logger.info(f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX") + self.ap.logger.info( + f"Created Milvus collection '{collection}' with dimension={vector_size}, index=AUTOINDEX" + ) else: # Ensure index exists for existing collection await self._ensure_index_if_missing(collection) @@ -137,11 +133,8 @@ class MilvusVectorDatabase(VectorDatabase): collection: Normalized collection name """ try: - indexes = await asyncio.to_thread( - self.client.list_indexes, - collection_name=collection - ) - if "vector" not in indexes: + indexes = await asyncio.to_thread(self.client.list_indexes, collection_name=collection) + if 'vector' not in indexes: await self._ensure_vector_index(collection) self.ap.logger.info(f"Created index for existing Milvus collection '{collection}'") except Exception as e: @@ -172,7 +165,7 @@ class MilvusVectorDatabase(VectorDatabase): metadatas: List of metadata dictionaries for each vector """ collection = self._normalize_collection_name(collection) - + if not embeddings_list: return @@ -184,39 +177,30 @@ class MilvusVectorDatabase(VectorDatabase): data = [] for i, vector_id in enumerate(ids): entry = { - "id": vector_id, - "vector": embeddings_list[i], + 'id': vector_id, + 'vector': embeddings_list[i], } # Add metadata fields if metadatas and i < len(metadatas): metadata = metadatas[i] # Add common metadata fields - if "text" in metadata: - entry["text"] = metadata["text"] - if "file_id" in metadata: - entry["file_id"] = metadata["file_id"] - if "uuid" in metadata: - entry["chunk_uuid"] = metadata["uuid"] + if 'text' in metadata: + entry['text'] = metadata['text'] + if 'file_id' in metadata: + entry['file_id'] = metadata['file_id'] + if 'uuid' in metadata: + entry['chunk_uuid'] = metadata['uuid'] data.append(entry) # Insert data into Milvus - await asyncio.to_thread( - self.client.insert, - collection_name=collection, - data=data - ) + await asyncio.to_thread(self.client.insert, collection_name=collection, data=data) # Load collection for searching (Milvus requires this) - await asyncio.to_thread( - self.client.load_collection, - collection_name=collection - ) + await asyncio.to_thread(self.client.load_collection, collection_name=collection) self.ap.logger.info(f"Added {len(ids)} embeddings to Milvus collection '{collection}'") - async def search( - self, collection: str, query_embedding: list[float], k: int = 5 - ) -> Dict[str, Any]: + async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> Dict[str, Any]: """Search for similar vectors in Milvus collection Args: @@ -231,10 +215,7 @@ class MilvusVectorDatabase(VectorDatabase): await self.get_or_create_collection(collection) # Perform search - search_params = { - "metric_type": "COSINE", - "params": {} - } + search_params = {'metric_type': 'COSINE', 'params': {}} results = await asyncio.to_thread( self.client.search, @@ -242,7 +223,7 @@ class MilvusVectorDatabase(VectorDatabase): data=[query_embedding], limit=k, search_params=search_params, - output_fields=["text", "file_id", "chunk_uuid"] + output_fields=['text', 'file_id', 'chunk_uuid'], ) # Convert results to Chroma-compatible format @@ -253,30 +234,24 @@ class MilvusVectorDatabase(VectorDatabase): if results and len(results) > 0: for hit in results[0]: - ids.append(hit.get("id", "")) - distances.append(hit.get("distance", 0.0)) + ids.append(hit.get('id', '')) + distances.append(hit.get('distance', 0.0)) # Build metadata from entity fields - entity = hit.get("entity", {}) + entity = hit.get('entity', {}) metadata = {} - if "text" in entity: - metadata["text"] = entity["text"] - if "file_id" in entity: - metadata["file_id"] = entity["file_id"] - if "chunk_uuid" in entity: - metadata["uuid"] = entity["chunk_uuid"] + if 'text' in entity: + metadata['text'] = entity['text'] + if 'file_id' in entity: + metadata['file_id'] = entity['file_id'] + if 'chunk_uuid' in entity: + metadata['uuid'] = entity['chunk_uuid'] metadatas.append(metadata) # Return in Chroma-compatible format (nested lists) - result = { - "ids": [ids], - "distances": [distances], - "metadatas": [metadatas] - } + result = {'ids': [ids], 'distances': [distances], 'metadatas': [metadatas]} - self.ap.logger.info( - f"Milvus search in '{collection}' returned {len(ids)} results" - ) + self.ap.logger.info(f"Milvus search in '{collection}' returned {len(ids)} results") return result async def delete_by_file_id(self, collection: str, file_id: str) -> None: @@ -290,14 +265,8 @@ class MilvusVectorDatabase(VectorDatabase): await self.get_or_create_collection(collection) # Delete entities matching the file_id - await asyncio.to_thread( - self.client.delete, - collection_name=collection, - filter=f'file_id == "{file_id}"' - ) - self.ap.logger.info( - f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}" - ) + await asyncio.to_thread(self.client.delete, collection_name=collection, filter=f'file_id == "{file_id}"') + self.ap.logger.info(f"Deleted embeddings from Milvus collection '{collection}' with file_id: {file_id}") async def delete_collection(self, collection: str): """Delete a Milvus collection @@ -306,18 +275,14 @@ class MilvusVectorDatabase(VectorDatabase): collection: Collection name to delete """ collection = self._normalize_collection_name(collection) - + self._collections.discard(collection) # Check if collection exists before attempting deletion - has_collection = await asyncio.to_thread( - self.client.has_collection, collection_name=collection - ) + has_collection = await asyncio.to_thread(self.client.has_collection, collection_name=collection) if has_collection: - await asyncio.to_thread( - self.client.drop_collection, collection_name=collection - ) + await asyncio.to_thread(self.client.drop_collection, collection_name=collection) self.ap.logger.info(f"Deleted Milvus collection '{collection}'") else: self.ap.logger.warning(f"Milvus collection '{collection}' not found") diff --git a/tests/unit_tests/config/test_env_override.py b/tests/unit_tests/config/test_env_override.py index d20988e9..0e309d4c 100644 --- a/tests/unit_tests/config/test_env_override.py +++ b/tests/unit_tests/config/test_env_override.py @@ -9,27 +9,28 @@ from typing import Any def _apply_env_overrides_to_config(cfg: dict) -> dict: """Apply environment variable overrides to data/config.yaml - - Environment variables should be uppercase and use __ (double underscore) + + Environment variables should be uppercase and use __ (double underscore) to represent nested keys. For example: - CONCURRENCY__PIPELINE overrides concurrency.pipeline - PLUGIN__RUNTIME_WS_URL overrides plugin.runtime_ws_url - + Arrays and dict types are ignored. - + Args: cfg: Configuration dictionary - + Returns: Updated configuration dictionary """ + def convert_value(value: str, original_value: Any) -> Any: """Convert string value to appropriate type based on original value - + Args: value: String value from environment variable original_value: Original value to infer type from - + Returns: Converted value (falls back to string if conversion fails) """ @@ -49,7 +50,7 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict: return value else: return value - + # Process environment variables for env_key, env_value in os.environ.items(): # Check if the environment variable is uppercase and contains __ @@ -57,18 +58,18 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict: continue if '__' not in env_key: continue - + # Convert environment variable name to config path # e.g., CONCURRENCY__PIPELINE -> ['concurrency', 'pipeline'] keys = [key.lower() for key in env_key.split('__')] - + # Navigate to the target value and validate the path current = cfg - + for i, key in enumerate(keys): if not isinstance(current, dict) or key not in current: break - + if i == len(keys) - 1: # At the final key - check if it's a scalar value if isinstance(current[key], (dict, list)): @@ -81,248 +82,182 @@ def _apply_env_overrides_to_config(cfg: dict) -> dict: else: # Navigate deeper current = current[key] - + return cfg class TestEnvOverrides: """Test environment variable override functionality""" - + def test_simple_string_override(self): """Test overriding a simple string value""" - cfg = { - 'api': { - 'port': 5300 - } - } - + cfg = {'api': {'port': 5300}} + # Set environment variable os.environ['API__PORT'] = '8080' - + result = _apply_env_overrides_to_config(cfg) - + assert result['api']['port'] == 8080 - + # Cleanup del os.environ['API__PORT'] - + def test_nested_key_override(self): """Test overriding nested keys with __ delimiter""" - cfg = { - 'concurrency': { - 'pipeline': 20, - 'session': 1 - } - } - + cfg = {'concurrency': {'pipeline': 20, 'session': 1}} + os.environ['CONCURRENCY__PIPELINE'] = '50' - + result = _apply_env_overrides_to_config(cfg) - + assert result['concurrency']['pipeline'] == 50 assert result['concurrency']['session'] == 1 # Unchanged - + del os.environ['CONCURRENCY__PIPELINE'] - + def test_deep_nested_override(self): """Test overriding deeply nested keys""" - cfg = { - 'system': { - 'jwt': { - 'expire': 604800, - 'secret': '' - } - } - } - + cfg = {'system': {'jwt': {'expire': 604800, 'secret': ''}}} + os.environ['SYSTEM__JWT__EXPIRE'] = '86400' os.environ['SYSTEM__JWT__SECRET'] = 'my_secret_key' - + result = _apply_env_overrides_to_config(cfg) - + assert result['system']['jwt']['expire'] == 86400 assert result['system']['jwt']['secret'] == 'my_secret_key' - + del os.environ['SYSTEM__JWT__EXPIRE'] del os.environ['SYSTEM__JWT__SECRET'] - + def test_underscore_in_key(self): """Test keys with underscores like runtime_ws_url""" - cfg = { - 'plugin': { - 'enable': True, - 'runtime_ws_url': 'ws://localhost:5400/control/ws' - } - } - + cfg = {'plugin': {'enable': True, 'runtime_ws_url': 'ws://localhost:5400/control/ws'}} + os.environ['PLUGIN__RUNTIME_WS_URL'] = 'ws://newhost:6000/ws' - + result = _apply_env_overrides_to_config(cfg) - + assert result['plugin']['runtime_ws_url'] == 'ws://newhost:6000/ws' - + del os.environ['PLUGIN__RUNTIME_WS_URL'] - + def test_boolean_conversion(self): """Test boolean value conversion""" - cfg = { - 'plugin': { - 'enable': True, - 'enable_marketplace': False - } - } - + cfg = {'plugin': {'enable': True, 'enable_marketplace': False}} + os.environ['PLUGIN__ENABLE'] = 'false' os.environ['PLUGIN__ENABLE_MARKETPLACE'] = 'true' - + result = _apply_env_overrides_to_config(cfg) - + assert result['plugin']['enable'] is False assert result['plugin']['enable_marketplace'] is True - + del os.environ['PLUGIN__ENABLE'] del os.environ['PLUGIN__ENABLE_MARKETPLACE'] - + def test_ignore_dict_type(self): """Test that dict types are ignored""" - cfg = { - 'database': { - 'use': 'sqlite', - 'sqlite': { - 'path': 'data/langbot.db' - } - } - } - + cfg = {'database': {'use': 'sqlite', 'sqlite': {'path': 'data/langbot.db'}}} + # Try to override a dict value - should be ignored os.environ['DATABASE__SQLITE'] = 'new_value' - + result = _apply_env_overrides_to_config(cfg) - + # Should remain a dict, not overridden assert isinstance(result['database']['sqlite'], dict) assert result['database']['sqlite']['path'] == 'data/langbot.db' - + del os.environ['DATABASE__SQLITE'] - + def test_ignore_list_type(self): """Test that list/array types are ignored""" - cfg = { - 'admins': ['admin1', 'admin2'], - 'command': { - 'enable': True, - 'prefix': ['!', '!'] - } - } - + cfg = {'admins': ['admin1', 'admin2'], 'command': {'enable': True, 'prefix': ['!', '!']}} + # Try to override list values - should be ignored os.environ['ADMINS'] = 'admin3' os.environ['COMMAND__PREFIX'] = '?' - + result = _apply_env_overrides_to_config(cfg) - + # Should remain lists, not overridden assert isinstance(result['admins'], list) assert result['admins'] == ['admin1', 'admin2'] assert isinstance(result['command']['prefix'], list) assert result['command']['prefix'] == ['!', '!'] - + del os.environ['ADMINS'] del os.environ['COMMAND__PREFIX'] - + def test_lowercase_env_var_ignored(self): """Test that lowercase environment variables are ignored""" - cfg = { - 'api': { - 'port': 5300 - } - } - + cfg = {'api': {'port': 5300}} + os.environ['api__port'] = '8080' - + result = _apply_env_overrides_to_config(cfg) - + # Should not be overridden assert result['api']['port'] == 5300 - + del os.environ['api__port'] - + def test_no_double_underscore_ignored(self): """Test that env vars without __ are ignored""" - cfg = { - 'api': { - 'port': 5300 - } - } - + cfg = {'api': {'port': 5300}} + os.environ['APIPORT'] = '8080' - + result = _apply_env_overrides_to_config(cfg) - + # Should not be overridden assert result['api']['port'] == 5300 - + del os.environ['APIPORT'] - + def test_nonexistent_key_ignored(self): """Test that env vars for non-existent keys are ignored""" - cfg = { - 'api': { - 'port': 5300 - } - } - + cfg = {'api': {'port': 5300}} + os.environ['API__NONEXISTENT'] = 'value' - + result = _apply_env_overrides_to_config(cfg) - + # Should not create new key assert 'nonexistent' not in result['api'] - + del os.environ['API__NONEXISTENT'] - + def test_integer_conversion(self): """Test integer value conversion""" - cfg = { - 'concurrency': { - 'pipeline': 20 - } - } - + cfg = {'concurrency': {'pipeline': 20}} + os.environ['CONCURRENCY__PIPELINE'] = '100' - + result = _apply_env_overrides_to_config(cfg) - + assert result['concurrency']['pipeline'] == 100 assert isinstance(result['concurrency']['pipeline'], int) - + del os.environ['CONCURRENCY__PIPELINE'] - + def test_multiple_overrides(self): """Test multiple environment variable overrides at once""" - cfg = { - 'api': { - 'port': 5300 - }, - 'concurrency': { - 'pipeline': 20, - 'session': 1 - }, - 'plugin': { - 'enable': False - } - } - + cfg = {'api': {'port': 5300}, 'concurrency': {'pipeline': 20, 'session': 1}, 'plugin': {'enable': False}} + os.environ['API__PORT'] = '8080' os.environ['CONCURRENCY__PIPELINE'] = '50' os.environ['PLUGIN__ENABLE'] = 'true' - + result = _apply_env_overrides_to_config(cfg) - + assert result['api']['port'] == 8080 assert result['concurrency']['pipeline'] == 50 assert result['plugin']['enable'] is True - + del os.environ['API__PORT'] del os.environ['CONCURRENCY__PIPELINE'] del os.environ['PLUGIN__ENABLE'] diff --git a/tests/unit_tests/plugin/test_plugin_component_filtering.py b/tests/unit_tests/plugin/test_plugin_component_filtering.py index b83667c5..c2c4fd76 100644 --- a/tests/unit_tests/plugin/test_plugin_component_filtering.py +++ b/tests/unit_tests/plugin/test_plugin_component_filtering.py @@ -1,6 +1,5 @@ """Test plugin list filtering by component kinds.""" -from datetime import datetime from unittest.mock import AsyncMock, MagicMock import pytest @@ -31,16 +30,7 @@ async def test_plugin_list_filter_by_component_kinds(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'Tool', - 'metadata': {'name': 'tool1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}], }, { 'debug': False, @@ -53,15 +43,8 @@ async def test_plugin_list_filter_by_component_kinds(): } }, 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'KnowledgeRetriever', - 'metadata': {'name': 'retriever1'} - } - } - } - ] + {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}} + ], }, { 'debug': False, @@ -73,16 +56,7 @@ async def test_plugin_list_filter_by_component_kinds(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'Command', - 'metadata': {'name': 'cmd1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'Command', 'metadata': {'name': 'cmd1'}}}}], }, { 'debug': False, @@ -94,16 +68,7 @@ async def test_plugin_list_filter_by_component_kinds(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'EventListener', - 'metadata': {'name': 'listener1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'EventListener', 'metadata': {'name': 'listener1'}}}}], }, { 'debug': False, @@ -116,23 +81,9 @@ async def test_plugin_list_filter_by_component_kinds(): } }, 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'KnowledgeRetriever', - 'metadata': {'name': 'retriever2'} - } - } - }, - { - 'manifest': { - 'manifest': { - 'kind': 'Tool', - 'metadata': {'name': 'tool2'} - } - } - } - ] + {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever2'}}}}, + {'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool2'}}}}, + ], }, ] @@ -187,16 +138,7 @@ async def test_plugin_list_filter_no_filter(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'Tool', - 'metadata': {'name': 'tool1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}], }, { 'debug': False, @@ -209,15 +151,8 @@ async def test_plugin_list_filter_no_filter(): } }, 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'KnowledgeRetriever', - 'metadata': {'name': 'retriever1'} - } - } - } - ] + {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}} + ], }, ] @@ -267,15 +202,8 @@ async def test_plugin_list_filter_empty_result(): } }, 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'KnowledgeRetriever', - 'metadata': {'name': 'retriever1'} - } - } - } - ] + {'manifest': {'manifest': {'kind': 'KnowledgeRetriever', 'metadata': {'name': 'retriever1'}}}} + ], }, ] @@ -321,16 +249,7 @@ async def test_plugin_list_filter_plugin_without_components(): } } }, - 'components': [ - { - 'manifest': { - 'manifest': { - 'kind': 'Tool', - 'metadata': {'name': 'tool1'} - } - } - } - ] + 'components': [{'manifest': {'manifest': {'kind': 'Tool', 'metadata': {'name': 'tool1'}}}}], }, { 'debug': False, @@ -342,7 +261,7 @@ async def test_plugin_list_filter_plugin_without_components(): } } }, - 'components': [] + 'components': [], }, ]