mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
Compare commits
62 Commits
feat/card_
...
validation
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
485f421920 | ||
|
|
329d813577 | ||
|
|
9ce42ddcb6 | ||
|
|
608ac82762 | ||
|
|
f516fa3a4f | ||
|
|
779cf9899f | ||
|
|
63a3f323e7 | ||
|
|
4a60bdb6b6 | ||
|
|
3ceb0c6829 | ||
|
|
31f4bc1ad6 | ||
|
|
d4602bca34 | ||
|
|
5c932c66e6 | ||
|
|
6a9f7e2c16 | ||
|
|
16901bc574 | ||
|
|
3a1ea8e945 | ||
|
|
cab5f99b97 | ||
|
|
560799cc33 | ||
|
|
8275cfd140 | ||
|
|
14330741cc | ||
|
|
7d0d37cac6 | ||
|
|
d43cbf0243 | ||
|
|
74f8a500b2 | ||
|
|
937110e193 | ||
|
|
ca74fc1ba4 | ||
|
|
29a0041887 | ||
|
|
2484ddc44d | ||
|
|
d89356af65 | ||
|
|
5a90b0e06b | ||
|
|
c2af8ff9c0 | ||
|
|
93589ee381 | ||
|
|
87c5aed9e7 | ||
|
|
aa4d46fd87 | ||
|
|
aa4b5d6732 | ||
|
|
748cc68667 | ||
|
|
bb55cd7ba9 | ||
|
|
3ba727f0e4 | ||
|
|
3eaadea3e0 | ||
|
|
1a3c73bc05 | ||
|
|
adb4b29c94 | ||
|
|
af58c34c26 | ||
|
|
12c9d02145 | ||
|
|
871c4525ca | ||
|
|
3872e3e1ac | ||
|
|
ea6ed9b7fd | ||
|
|
70ec75f9a2 | ||
|
|
9e1ff7f85c | ||
|
|
91e99e2f46 | ||
|
|
59871c3118 | ||
|
|
3780a68dfa | ||
|
|
9908dc7800 | ||
|
|
84afe8551d | ||
|
|
53747fc1f0 | ||
|
|
1f855c3e7f | ||
|
|
66a0a7c9c8 | ||
|
|
25bf3ea0b3 | ||
|
|
d2c7a51e46 | ||
|
|
d38e3d9181 | ||
|
|
77be87ed40 | ||
|
|
27227aa31f | ||
|
|
1af2cb5bc2 | ||
|
|
37641f05f2 | ||
|
|
4bb0b49907 |
@@ -47,8 +47,6 @@ LangBot is an **open-source, production-grade platform** for building AI-powered
|
||||
|
||||
[→ Learn more about all features](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Practical guides: [deploy a multi-platform AI bot in 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connect DeepSeek to WeChat, Discord, and Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [run a Dify Agent in Discord, Telegram, and Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), and [build an n8n-powered chatbot](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -47,8 +47,6 @@ LangBot 是一个**开源的生产级平台**,用于构建 AI 驱动的即时
|
||||
|
||||
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
📍 实践指南:[5 分钟部署多平台 AI 机器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[将 DeepSeek 接入微信、企业微信与 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[让 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 构建多平台 AI 聊天机器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
## 快速开始
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot es una **plataforma de código abierto y grado de producción** para con
|
||||
|
||||
[→ Conocer más sobre todas las funcionalidades](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Guías prácticas: [desplegar un bot de IA multiplataforma en 5 minutos](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [conectar DeepSeek a WeChat, Discord y Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [ejecutar un Dify Agent en Discord, Telegram y Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) y [crear un chatbot con n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Inicio Rápido
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot est une **plateforme open-source de niveau production** pour créer des
|
||||
|
||||
[→ En savoir plus sur toutes les fonctionnalités](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Guides pratiques : [déployer un bot IA multiplateforme en 5 minutes](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [connecter DeepSeek à WeChat, Discord et Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [exécuter un Dify Agent dans Discord, Telegram et Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) et [créer un chatbot avec n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Démarrage Rapide
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot は、AI搭載のインスタントメッセージングボットを構
|
||||
|
||||
[→ すべての機能について詳しく見る](https://link.langbot.app/ja/docs/features)
|
||||
|
||||
📍 実践ガイド: [5分でマルチプラットフォームAIボットをデプロイ](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/)、[DeepSeekをWeChat・Discord・Telegramに接続](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/)、[Dify AgentをDiscord・Telegram・Slackで動かす](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/)、[n8n連携チャットボットを構築](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
## クイックスタート
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot은 AI 기반 인스턴트 메시징 봇을 구축하기 위한 **오픈
|
||||
|
||||
[→ 모든 기능 자세히 보기](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 실전 가이드: [5분 만에 멀티 플랫폼 AI 봇 배포하기](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [DeepSeek를 WeChat, Discord, Telegram에 연결하기](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [Dify Agent를 Discord, Telegram, Slack에서 실행하기](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/), [n8n 기반 챗봇 만들기](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## 빠른 시작
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot — это **платформа с открытым исходным к
|
||||
|
||||
[→ Подробнее обо всех возможностях](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Практические руководства: [развернуть мультиплатформенного ИИ-бота за 5 минут](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [подключить DeepSeek к WeChat, Discord и Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [запустить Dify Agent в Discord, Telegram и Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) и [создать чат-бота на n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Быстрый старт
|
||||
|
||||
@@ -48,8 +48,6 @@ LangBot 是一個**開源的生產級平台**,用於建構 AI 驅動的即時
|
||||
|
||||
[→ 了解更多功能特性](https://link.langbot.app/zh/docs/features)
|
||||
|
||||
📍 實踐指南:[5 分鐘部署多平台 AI 機器人](https://blog.langbot.app/zh/blog/deploy-ai-bot-in-5-minutes/)、[將 DeepSeek 接入微信、企業微信與 Discord](https://blog.langbot.app/zh/blog/connect-deepseek-to-wechat/)、[讓 Dify Agent 跑在 Discord、Telegram 和 Slack 上](https://blog.langbot.app/zh/blog/dify-agent-discord-telegram-slack/),以及[用 n8n 建構多平台 AI 聊天機器人](https://blog.langbot.app/zh/blog/n8n-multi-platform-ai-chatbot/)。
|
||||
|
||||
---
|
||||
|
||||
## 快速開始
|
||||
|
||||
@@ -46,8 +46,6 @@ LangBot là một **nền tảng mã nguồn mở, cấp sản xuất** để x
|
||||
|
||||
[→ Tìm hiểu thêm về tất cả tính năng](https://link.langbot.app/en/docs/features)
|
||||
|
||||
📍 Hướng dẫn thực hành: [triển khai bot AI đa nền tảng trong 5 phút](https://blog.langbot.app/en/blog/deploy-ai-bot-in-5-minutes/), [kết nối DeepSeek với WeChat, Discord và Telegram](https://blog.langbot.app/en/blog/connect-deepseek-to-wechat/), [chạy Dify Agent trên Discord, Telegram và Slack](https://blog.langbot.app/en/blog/dify-agent-discord-telegram-slack/) và [xây dựng chatbot với n8n](https://blog.langbot.app/en/blog/n8n-multi-platform-ai-chatbot/).
|
||||
|
||||
---
|
||||
|
||||
## Bắt đầu nhanh
|
||||
|
||||
@@ -109,61 +109,6 @@ class AsyncDifyServiceClient:
|
||||
if chunk.startswith('data:'):
|
||||
yield json.loads(chunk[5:])
|
||||
|
||||
async def workflow_submit(
|
||||
self,
|
||||
form_token: str,
|
||||
workflow_run_id: str,
|
||||
inputs: dict[str, typing.Any],
|
||||
user: str,
|
||||
action: str = '',
|
||||
timeout: float = 120.0,
|
||||
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
|
||||
"""Submit human input to resume a paused workflow, then stream events.
|
||||
|
||||
1. POST /form/human_input/{form_token} to submit the form
|
||||
2. GET /workflow/{task_id}/events to stream the resumed workflow events
|
||||
"""
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
trust_env=True,
|
||||
timeout=timeout,
|
||||
) as client:
|
||||
# Step 1: Submit the form
|
||||
payload: dict[str, typing.Any] = {
|
||||
'inputs': inputs if isinstance(inputs, dict) else {},
|
||||
'user': user,
|
||||
'action': action,
|
||||
}
|
||||
|
||||
submit_resp = await client.post(
|
||||
f'/form/human_input/{form_token}',
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
if submit_resp.status_code != 200:
|
||||
raise DifyAPIError(f'{submit_resp.status_code} {submit_resp.text}')
|
||||
|
||||
# Step 2: Stream resumed workflow events
|
||||
async with client.stream(
|
||||
'GET',
|
||||
f'/workflow/{workflow_run_id}/events',
|
||||
headers={'Authorization': f'Bearer {self.api_key}'},
|
||||
params={'user': user},
|
||||
) as r:
|
||||
async for chunk in r.aiter_lines():
|
||||
if r.status_code != 200:
|
||||
raise DifyAPIError(f'{r.status_code} {chunk}')
|
||||
if chunk.strip() == '':
|
||||
continue
|
||||
if chunk.startswith('data:'):
|
||||
yield json.loads(chunk[5:])
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
file: httpx._types.FileTypes,
|
||||
|
||||
@@ -179,6 +179,8 @@ class AdaptersRouterGroup(group.RouterGroup):
|
||||
"""Start WeChat QR code login. Returns session_id + QR code data URL."""
|
||||
import uuid
|
||||
import time
|
||||
import io
|
||||
import base64
|
||||
|
||||
from langbot.libs.openclaw_weixin_api.client import OpenClawWeixinClient, DEFAULT_BASE_URL
|
||||
|
||||
@@ -206,32 +208,60 @@ class AdaptersRouterGroup(group.RouterGroup):
|
||||
|
||||
async def run_login():
|
||||
try:
|
||||
import qrcode as qr_lib
|
||||
|
||||
def on_qrcode(qr_data_url: str, _qr_url: str):
|
||||
def _update():
|
||||
session['qr_data_url'] = qr_data_url
|
||||
session['expire_at'] = time.time() + 180
|
||||
for _attempt in range(3):
|
||||
qr_resp = await client.fetch_qrcode()
|
||||
if not qr_resp.qrcode or not qr_resp.qrcode_img_content:
|
||||
raise Exception('Failed to get QR code from server')
|
||||
|
||||
# Generate QR code image locally
|
||||
qr = qr_lib.QRCode(error_correction=qr_lib.constants.ERROR_CORRECT_L)
|
||||
qr.add_data(qr_resp.qrcode_img_content)
|
||||
qr.make(fit=True)
|
||||
img = qr.make_image(fill_color='black', back_color='white')
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format='PNG')
|
||||
b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
||||
data_url = f'data:image/png;base64,{b64}'
|
||||
|
||||
def _update_qr():
|
||||
session['qr_data_url'] = data_url
|
||||
session['expire_at'] = time.time() + 480 # 8 minutes
|
||||
session['status'] = 'waiting'
|
||||
|
||||
loop.call_soon_threadsafe(_update)
|
||||
loop.call_soon_threadsafe(_update_qr)
|
||||
|
||||
# Poll for scan status
|
||||
deadline = loop.time() + 180
|
||||
while loop.time() < deadline:
|
||||
try:
|
||||
status_resp = await client.poll_qrcode_status(qr_resp.qrcode)
|
||||
except Exception:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
if status_resp.status == 'confirmed' and status_resp.bot_token:
|
||||
session['status'] = 'success'
|
||||
session['token'] = status_resp.bot_token
|
||||
session['base_url'] = status_resp.baseurl or client.base_url
|
||||
session['account_id'] = status_resp.ilink_bot_id or ''
|
||||
return
|
||||
|
||||
if status_resp.status == 'expired':
|
||||
break # retry with new QR code
|
||||
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
pass # timeout, retry
|
||||
|
||||
# All retries exhausted
|
||||
session['status'] = 'error'
|
||||
session['error'] = 'QR code login failed: max retries exceeded'
|
||||
|
||||
result = await client.login(
|
||||
max_retries=1,
|
||||
poll_timeout_ms=180_000,
|
||||
on_qrcode=on_qrcode,
|
||||
)
|
||||
session['status'] = 'success'
|
||||
session['token'] = result.token
|
||||
session['base_url'] = result.base_url
|
||||
session['account_id'] = result.account_id
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
if 'expired' in error_message.lower() or 'max retries exceeded' in error_message.lower():
|
||||
session['status'] = 'expired'
|
||||
session['error'] = 'QR code expired'
|
||||
else:
|
||||
session['status'] = 'error'
|
||||
session['error'] = error_message
|
||||
session['status'] = 'error'
|
||||
session['error'] = str(e)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
@@ -265,11 +295,7 @@ class AdaptersRouterGroup(group.RouterGroup):
|
||||
if not session:
|
||||
return self.http_status(404, -1, 'Session not found')
|
||||
|
||||
data = {
|
||||
'status': session['status'],
|
||||
'qr_data_url': session['qr_data_url'],
|
||||
'expire_at': session['expire_at'],
|
||||
}
|
||||
data = {'status': session['status']}
|
||||
|
||||
if session['status'] == 'success':
|
||||
data['token'] = session['token']
|
||||
@@ -279,9 +305,6 @@ class AdaptersRouterGroup(group.RouterGroup):
|
||||
elif session['status'] == 'error':
|
||||
data['error'] = session['error']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
elif session['status'] == 'expired':
|
||||
data['error'] = session['error']
|
||||
_weixin_login_sessions.pop(session_id, None)
|
||||
|
||||
return self.success(data=data)
|
||||
|
||||
|
||||
@@ -7,10 +7,8 @@ import httpx
|
||||
import uuid
|
||||
import os
|
||||
import posixpath
|
||||
import sqlalchemy
|
||||
|
||||
from .....core import taskmgr
|
||||
from .....entity.persistence import plugin as persistence_plugin
|
||||
from .. import group
|
||||
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||
|
||||
@@ -150,15 +148,7 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
return self.http_status(404, -1, 'plugin not found')
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_plugin.PluginSetting.config)
|
||||
.where(persistence_plugin.PluginSetting.plugin_author == author)
|
||||
.where(persistence_plugin.PluginSetting.plugin_name == plugin_name)
|
||||
)
|
||||
persisted_config = result.scalar_one_or_none()
|
||||
|
||||
config = persisted_config if persisted_config is not None else plugin['plugin_config']
|
||||
return self.success(data={'config': config})
|
||||
return self.success(data={'config': plugin['plugin_config']})
|
||||
elif quart.request.method == 'PUT':
|
||||
data = await quart.request.json
|
||||
|
||||
|
||||
@@ -140,6 +140,17 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
async def _() -> str:
|
||||
return self.success(data=await self.ap.maintenance_service.get_storage_analysis())
|
||||
|
||||
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
py_code = await quart.request.data
|
||||
|
||||
ap = self.ap
|
||||
|
||||
return self.success(data=exec(py_code, {'ap': ap}))
|
||||
|
||||
@self.route(
|
||||
'/debug/plugin/action',
|
||||
methods=['POST'],
|
||||
|
||||
@@ -157,7 +157,7 @@ class RuntimePipeline:
|
||||
bot_message=query.resp_messages[-1],
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
|
||||
is_final=[msg.is_final for msg in query.resp_messages][-1],
|
||||
is_final=[msg.is_final for msg in query.resp_messages][0],
|
||||
)
|
||||
else:
|
||||
await query.adapter.reply_message(
|
||||
|
||||
@@ -42,13 +42,9 @@ class QueryPool:
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
pipeline_uuid: typing.Optional[str] = None,
|
||||
routed_by_rule: bool = False,
|
||||
variables: typing.Optional[dict[str, typing.Any]] = None,
|
||||
) -> pipeline_query.Query:
|
||||
async with self.condition:
|
||||
query_id = self.query_id_counter
|
||||
initial_variables: dict[str, typing.Any] = {'_routed_by_rule': routed_by_rule}
|
||||
if variables:
|
||||
initial_variables.update(variables)
|
||||
query = pipeline_query.Query(
|
||||
bot_uuid=bot_uuid,
|
||||
query_id=query_id,
|
||||
@@ -57,7 +53,7 @@ class QueryPool:
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
variables=initial_variables,
|
||||
variables={'_routed_by_rule': routed_by_rule},
|
||||
resp_messages=[],
|
||||
resp_message_chain=[],
|
||||
adapter=adapter,
|
||||
|
||||
@@ -40,7 +40,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
|
||||
# TODO 命令与流式的兼容性问题
|
||||
if await query.adapter.is_stream_output_supported() and has_chunks:
|
||||
is_final = [msg.is_final for msg in query.resp_messages][-1]
|
||||
is_final = [msg.is_final for msg in query.resp_messages][0]
|
||||
await query.adapter.reply_message_chunk(
|
||||
message_source=query.message_event,
|
||||
bot_message=query.resp_messages[-1],
|
||||
|
||||
@@ -501,8 +501,6 @@ class PlatformManager:
|
||||
bot_entity.adapter_config,
|
||||
logger,
|
||||
)
|
||||
if hasattr(adapter_inst, 'ap'):
|
||||
adapter_inst.ap = self.ap
|
||||
|
||||
# 如果 adapter 支持 set_bot_uuid 方法,设置 bot_uuid(用于统一 webhook)
|
||||
if hasattr(adapter_inst, 'set_bot_uuid'):
|
||||
|
||||
@@ -3,7 +3,6 @@ import typing
|
||||
import asyncio
|
||||
import traceback
|
||||
import datetime
|
||||
import json
|
||||
|
||||
import aiocqhttp
|
||||
import pydantic
|
||||
@@ -294,29 +293,6 @@ class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConvert
|
||||
elif msg.type == 'dice':
|
||||
face_id = msg.data['result']
|
||||
yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子'))
|
||||
elif msg.type == 'json':
|
||||
try:
|
||||
raw = msg.data.get('data', {})
|
||||
if isinstance(raw, str):
|
||||
raw = json.loads(raw)
|
||||
if isinstance(raw, dict):
|
||||
_meta = raw.get('meta', {}) or {}
|
||||
if isinstance(_meta, dict):
|
||||
_detail = _meta.get('detail_1') or _meta.get('music') or _meta.get('news') or {}
|
||||
else:
|
||||
_detail = {}
|
||||
if isinstance(_detail, dict):
|
||||
preview = _detail.get('preview', '')
|
||||
title = _detail.get('desc', '') or _detail.get('title', '')
|
||||
url = _detail.get('qqdocurl', '') or _detail.get('jumpUrl', '')
|
||||
else:
|
||||
preview = title = url = ''
|
||||
text = ' '.join([f'[{raw.get("app", "")}]', preview, title, url]).strip()
|
||||
yiri_msg_list.append(platform_message.Plain(text=text or '[收到一张JSON卡片]'))
|
||||
else:
|
||||
yiri_msg_list.append(platform_message.Plain(text=str(raw)))
|
||||
except Exception:
|
||||
yiri_msg_list.append(platform_message.Plain(text='[收到一张JSON卡片]'))
|
||||
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
|
||||
class AESCipher(object):
|
||||
@@ -771,7 +770,6 @@ CARD_ID_CACHE_MAX_LIFETIME = 20 * 60 # 20分钟
|
||||
class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: lark_oapi.ws.Client = pydantic.Field(exclude=True)
|
||||
api_client: lark_oapi.Client = pydantic.Field(exclude=True)
|
||||
ap: typing.Any = pydantic.Field(exclude=True, default=None)
|
||||
|
||||
bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识
|
||||
lark_tenant_key: str = pydantic.Field(exclude=True, default='') # 飞书企业key
|
||||
@@ -794,16 +792,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
pending_monitoring_msg: dict[str, str]
|
||||
# Final: reply Lark message ID → (monitoring_message_id, timestamp) (used by feedback callbacks)
|
||||
reply_to_monitoring_msg: dict[str, tuple[str, float]]
|
||||
reply_message_card_ids: dict[str, str]
|
||||
card_sequence_dict: dict[str, int]
|
||||
# card_id → set of source message ids registered against it (for cleanup)
|
||||
card_id_to_source_ids: dict[str, set[str]]
|
||||
# card_id → current streaming_txt content cache (needed for full aupdate during resume transition)
|
||||
card_streaming_text: dict[str, str]
|
||||
# card_id → pre-pause streaming_txt text (captured when resume first chunk arrives)
|
||||
card_pre_pause_text: dict[str, str]
|
||||
# set of card_ids that have already transitioned from "buttons visible" to "resume layout"
|
||||
card_resume_transitioned: set[str]
|
||||
_MONITORING_MAPPING_TTL = 600 # 10 minutes
|
||||
|
||||
seq: int # 用于在发送卡片消息中识别消息顺序,直接以seq作为标识
|
||||
@@ -824,134 +812,11 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
def sync_on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):
|
||||
asyncio.create_task(on_message(event))
|
||||
|
||||
def schedule_on_app_loop(coro):
|
||||
"""Run a coroutine on the application event loop from sync callbacks."""
|
||||
return asyncio.run_coroutine_threadsafe(coro, self.ap.event_loop)
|
||||
|
||||
def sync_on_card_action(event):
|
||||
try:
|
||||
action_value_raw = getattr(getattr(event.event, 'action', None), 'value', {})
|
||||
# Parse JSON string values (from form action buttons)
|
||||
if isinstance(action_value_raw, str):
|
||||
try:
|
||||
action_value_obj = json.loads(action_value_raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
action_value_obj = {}
|
||||
else:
|
||||
action_value_obj = action_value_raw if isinstance(action_value_raw, dict) else {}
|
||||
action_value_obj = getattr(getattr(event.event, 'action', None), 'value', {})
|
||||
action_value = action_value_obj.get('feedback', '') if isinstance(action_value_obj, dict) else ''
|
||||
|
||||
# Handle Dify form action button clicks
|
||||
if isinstance(action_value_obj, dict) and action_value_obj.get('form_action'):
|
||||
form_token = action_value_obj.get('form_token', '')
|
||||
workflow_run_id = action_value_obj.get('workflow_run_id', '')
|
||||
action_id = action_value_obj.get('action_id', '')
|
||||
session_key = action_value_obj.get('session_key', '')
|
||||
|
||||
if session_key.startswith('group_') or session_key.startswith('g:'):
|
||||
launcher_type = provider_session.LauncherTypes.GROUP
|
||||
launcher_id = (
|
||||
session_key.split(':', 1)[1]
|
||||
if session_key.startswith('g:')
|
||||
else session_key[len('group_') :]
|
||||
)
|
||||
else:
|
||||
launcher_type = provider_session.LauncherTypes.PERSON
|
||||
launcher_id = (
|
||||
session_key.split(':', 1)[1]
|
||||
if session_key.startswith('p:')
|
||||
else session_key[len('person_') :]
|
||||
)
|
||||
|
||||
# Find the bot entity to get bot_uuid and pipeline_uuid
|
||||
bot_uuid = ''
|
||||
pipeline_uuid = None
|
||||
for bot in self.ap.platform_mgr.bots:
|
||||
if bot.adapter is self:
|
||||
bot_uuid = bot.bot_entity.uuid
|
||||
pipeline_uuid = bot.bot_entity.use_pipeline_uuid
|
||||
break
|
||||
|
||||
form_action_data = {
|
||||
'form_token': form_token,
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'action_id': action_id,
|
||||
'user': f'{launcher_type.value}_{launcher_id}',
|
||||
'inputs': {},
|
||||
}
|
||||
|
||||
context = getattr(event.event, 'context', None)
|
||||
open_message_id = getattr(context, 'open_message_id', None)
|
||||
source_time = datetime.datetime.now()
|
||||
event_time = source_time.timestamp()
|
||||
action_text = action_value_obj.get('action_id', 'confirm')
|
||||
message_chain = platform_message.MessageChain(
|
||||
[platform_message.Plain(text=f'[Form Action: {action_text}]')]
|
||||
)
|
||||
if open_message_id:
|
||||
message_chain.insert(
|
||||
0,
|
||||
platform_message.Source(
|
||||
id=open_message_id,
|
||||
time=source_time,
|
||||
),
|
||||
)
|
||||
|
||||
operator = getattr(event.event, 'operator', None)
|
||||
user_id = (
|
||||
getattr(operator, 'open_id', None) or getattr(operator, 'user_id', None) or str(launcher_id)
|
||||
)
|
||||
|
||||
if launcher_type == provider_session.LauncherTypes.GROUP:
|
||||
synthetic_event = platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=user_id,
|
||||
member_name='',
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=platform_entities.Group(
|
||||
id=launcher_id,
|
||||
name='',
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event_time,
|
||||
source_platform_object=event,
|
||||
)
|
||||
else:
|
||||
synthetic_event = platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=user_id,
|
||||
nickname='',
|
||||
remark='',
|
||||
),
|
||||
message_chain=message_chain,
|
||||
time=event_time,
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
async def add_form_action_query():
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=bot_uuid,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=user_id,
|
||||
message_event=synthetic_event,
|
||||
message_chain=message_chain,
|
||||
adapter=self,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
variables={
|
||||
'_dify_form_action': form_action_data,
|
||||
'_routed_by_rule': True,
|
||||
},
|
||||
)
|
||||
|
||||
schedule_on_app_loop(add_form_action_query())
|
||||
|
||||
from lark_oapi.event.callback.model.p2_card_action_trigger import P2CardActionTriggerResponse
|
||||
|
||||
return P2CardActionTriggerResponse({'toast': {'type': 'success', 'content': '操作成功'}})
|
||||
|
||||
if action_value == '有帮助':
|
||||
feedback_type = 1
|
||||
elif action_value == '无帮助':
|
||||
@@ -992,14 +857,17 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
)
|
||||
|
||||
if platform_events.FeedbackEvent in self.listeners:
|
||||
schedule_on_app_loop(self.listeners[platform_events.FeedbackEvent](feedback_event, self))
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
asyncio.create_task(self.listeners[platform_events.FeedbackEvent](feedback_event, self))
|
||||
else:
|
||||
loop.run_until_complete(self.listeners[platform_events.FeedbackEvent](feedback_event, self))
|
||||
|
||||
from lark_oapi.event.callback.model.p2_card_action_trigger import P2CardActionTriggerResponse
|
||||
|
||||
return P2CardActionTriggerResponse({'toast': {'type': 'success', 'content': '感谢您的反馈'}})
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
schedule_on_app_loop(self.logger.error(f'Error in lark card action callback: {traceback.format_exc()}'))
|
||||
asyncio.create_task(self.logger.error(f'Error in lark card action callback: {traceback.format_exc()}'))
|
||||
from lark_oapi.event.callback.model.p2_card_action_trigger import P2CardActionTriggerResponse
|
||||
|
||||
return P2CardActionTriggerResponse({'toast': {'type': 'error', 'content': '反馈处理失败'}})
|
||||
@@ -1025,12 +893,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
card_id_dict={},
|
||||
pending_monitoring_msg={},
|
||||
reply_to_monitoring_msg={},
|
||||
reply_message_card_ids={},
|
||||
card_sequence_dict={},
|
||||
card_id_to_source_ids={},
|
||||
card_streaming_text={},
|
||||
card_pre_pause_text={},
|
||||
card_resume_transitioned=set(),
|
||||
seq=1,
|
||||
listeners={},
|
||||
quart_app=quart_app,
|
||||
@@ -1270,33 +1132,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
for k in expired:
|
||||
del self.reply_to_monitoring_msg[k]
|
||||
|
||||
def _next_card_sequence(self, card_id: str, suggested: int = 1) -> int:
|
||||
"""Return the next strictly increasing sequence for a card update."""
|
||||
current = self.card_sequence_dict.get(card_id, 0)
|
||||
next_seq = max(current + 1, suggested)
|
||||
self.card_sequence_dict[card_id] = next_seq
|
||||
return next_seq
|
||||
|
||||
def _register_card_for_source(self, card_id: str, *source_ids: str) -> None:
|
||||
"""Register a card_id under one or more source message ids."""
|
||||
bucket = self.card_id_to_source_ids.setdefault(card_id, set())
|
||||
for sid in source_ids:
|
||||
if not sid:
|
||||
continue
|
||||
self.reply_message_card_ids[sid] = card_id
|
||||
bucket.add(sid)
|
||||
|
||||
def _drop_card_state(self, card_id: str) -> None:
|
||||
"""Pop all per-card state for the given card_id."""
|
||||
if not card_id:
|
||||
return
|
||||
for sid in self.card_id_to_source_ids.pop(card_id, set()):
|
||||
self.reply_message_card_ids.pop(sid, None)
|
||||
self.card_sequence_dict.pop(card_id, None)
|
||||
self.card_streaming_text.pop(card_id, None)
|
||||
self.card_pre_pause_text.pop(card_id, None)
|
||||
self.card_resume_transitioned.discard(card_id)
|
||||
|
||||
async def create_card_id(self, message_id):
|
||||
try:
|
||||
# self.logger.debug('飞书支持stream输出,创建卡片......')
|
||||
@@ -1492,7 +1327,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
self.card_id_dict[message_id] = response.data.card_id
|
||||
|
||||
card_id = response.data.card_id
|
||||
self.card_sequence_dict[card_id] = 0
|
||||
return card_id
|
||||
|
||||
except Exception as e:
|
||||
@@ -1505,12 +1339,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
"""
|
||||
# message_id = event.message_chain.message_id
|
||||
|
||||
source_message_id = str(event.message_chain.message_id)
|
||||
existing_card_id = self.reply_message_card_ids.get(source_message_id)
|
||||
if existing_card_id:
|
||||
self.card_id_dict[message_id] = existing_card_id
|
||||
return True
|
||||
|
||||
card_id = await self.create_card_id(message_id)
|
||||
content = {
|
||||
'type': 'card',
|
||||
@@ -1549,16 +1377,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
user_msg_id = event.message_chain.message_id
|
||||
reply_msg_id = getattr(response.data, 'message_id', None)
|
||||
monitoring_msg_id = self.pending_monitoring_msg.pop(user_msg_id, None)
|
||||
# Register the card under both the user-incoming msg id (so a
|
||||
# second reply_message_first_chunk for the same user message
|
||||
# reuses this card) AND the bot-reply msg id (so a synthetic
|
||||
# event from a form-button callback — whose Source.id equals
|
||||
# the bot's card message id — hits the same card and renders
|
||||
# the resume content into it).
|
||||
if reply_msg_id:
|
||||
self._register_card_for_source(card_id, str(user_msg_id), str(reply_msg_id))
|
||||
else:
|
||||
self._register_card_for_source(card_id, str(user_msg_id))
|
||||
if reply_msg_id and monitoring_msg_id:
|
||||
self.reply_to_monitoring_msg[reply_msg_id] = (monitoring_msg_id, time.time())
|
||||
self._cleanup_monitoring_mapping()
|
||||
@@ -1567,93 +1385,6 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
|
||||
return True
|
||||
|
||||
async def _open_new_form_card(
|
||||
self,
|
||||
message_id: str,
|
||||
message_source: platform_events.MessageEvent,
|
||||
form_data: dict,
|
||||
) -> str | None:
|
||||
"""Spawn a fresh card to host a re-paused human-input prompt.
|
||||
|
||||
Creates a new card_id (rebinding ``self.card_id_dict[message_id]``),
|
||||
replies it to the current incoming message so it appears as the next
|
||||
step in the chat, registers the new reply_msg_id so subsequent button
|
||||
callbacks resolve back to it, and renders the prompt + buttons on it.
|
||||
|
||||
Returns the new card_id, or ``None`` if creation failed (caller is
|
||||
responsible for falling back to in-place update so the workflow
|
||||
remains continuable).
|
||||
"""
|
||||
source_message_id = getattr(message_source.message_chain, 'message_id', None)
|
||||
if not source_message_id:
|
||||
await self.logger.error('Cannot open new form card: source message_id missing')
|
||||
return None
|
||||
|
||||
try:
|
||||
new_card_id = await self.create_card_id(message_id)
|
||||
except Exception:
|
||||
await self.logger.error(f'Failed to create new form card: {traceback.format_exc()}')
|
||||
return None
|
||||
|
||||
tenant_key = (
|
||||
message_source.source_platform_object.header.tenant_key if message_source.source_platform_object else None
|
||||
)
|
||||
app_access_token = self.get_app_access_token()
|
||||
tenant_access_token = self.get_tenant_access_token(tenant_key)
|
||||
req_opt: RequestOption = (
|
||||
RequestOption.builder()
|
||||
.app_ticket(self.app_ticket)
|
||||
.tenant_key(tenant_key)
|
||||
.app_access_token(app_access_token)
|
||||
.tenant_access_token(tenant_access_token)
|
||||
.build()
|
||||
)
|
||||
|
||||
content = {
|
||||
'type': 'card',
|
||||
'data': {'card_id': new_card_id, 'template_variable': {'content': ''}},
|
||||
}
|
||||
request: ReplyMessageRequest = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(str(source_message_id))
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(content))
|
||||
.msg_type('interactive')
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
try:
|
||||
response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(request, req_opt)
|
||||
except Exception:
|
||||
await self.logger.error(f'Failed to send new form card: {traceback.format_exc()}')
|
||||
return None
|
||||
|
||||
if not response.success():
|
||||
await self.logger.error(
|
||||
f'Failed to send new form card: code={response.code}, msg={response.msg}, '
|
||||
f'log_id={response.get_log_id()}'
|
||||
)
|
||||
return None
|
||||
|
||||
reply_msg_id = getattr(response.data, 'message_id', None)
|
||||
if reply_msg_id:
|
||||
self._register_card_for_source(new_card_id, str(source_message_id), str(reply_msg_id))
|
||||
|
||||
sequence = self._next_card_sequence(new_card_id, 1)
|
||||
await self._update_card_layout(
|
||||
card_id=new_card_id,
|
||||
message_source=message_source,
|
||||
text_message='',
|
||||
sequence=sequence,
|
||||
form_data=form_data,
|
||||
show_form_prompt=True,
|
||||
)
|
||||
return new_card_id
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
@@ -1773,492 +1504,45 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
):
|
||||
"""
|
||||
回复消息变成更新卡片消息
|
||||
|
||||
Supports Dify form-action resume: when the runner yields a chunk with
|
||||
``_resume_from_form=True``, the card transitions from buttons to a
|
||||
grey "已选择" notice and a new ``streaming_txt_resume`` element is added
|
||||
for subsequent resume chunks to stream into.
|
||||
|
||||
When ``_open_new_card=True`` on the final chunk, the existing card is
|
||||
left as-is and the pipeline will create a new card (with fresh form
|
||||
buttons) for the re-pause.
|
||||
"""
|
||||
# self.seq += 1
|
||||
message_id = bot_message.resp_message_id
|
||||
msg_seq = bot_message.msg_sequence
|
||||
if msg_seq % 8 == 0 or is_final:
|
||||
text_elements, media_items = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
form_data = getattr(bot_message, '_form_data', None)
|
||||
resume_from = getattr(bot_message, '_resume_from_form', False)
|
||||
action_title = getattr(bot_message, '_resume_action_title', '')
|
||||
resume_node_title = getattr(bot_message, '_resume_node_title', '')
|
||||
open_new_card = getattr(bot_message, '_open_new_card', False)
|
||||
if action_title:
|
||||
if resume_node_title:
|
||||
selected_notice = f'**{resume_node_title}**\n已选择:{action_title}'
|
||||
else:
|
||||
selected_notice = f'**已选择**:{action_title}'
|
||||
else:
|
||||
selected_notice = ''
|
||||
text_message = ''
|
||||
if text_elements:
|
||||
parts = []
|
||||
for paragraph in text_elements:
|
||||
para_text = ''.join(ele['text'] for ele in paragraph if ele['tag'] in ('text', 'md'))
|
||||
if para_text:
|
||||
parts.append(para_text)
|
||||
text_message = '\n\n'.join(parts)
|
||||
|
||||
# ── decide whether this chunk needs a card update ────────────────────
|
||||
card_id = self.card_id_dict.get(message_id)
|
||||
if not card_id:
|
||||
return
|
||||
# content = {
|
||||
# 'type': 'card_json',
|
||||
# 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}},
|
||||
# }
|
||||
|
||||
# ── convert message chain → text ─────────────────────────────────────
|
||||
text_elements, media_items = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
text_message = ''
|
||||
if text_elements:
|
||||
parts = []
|
||||
for paragraph in text_elements:
|
||||
para_text = ''.join(ele['text'] for ele in paragraph if ele['tag'] in ('text', 'md'))
|
||||
if para_text:
|
||||
parts.append(para_text)
|
||||
text_message = '\n\n'.join(parts)
|
||||
|
||||
tenant_key = (
|
||||
message_source.source_platform_object.header.tenant_key if message_source.source_platform_object else None
|
||||
)
|
||||
app_access_token = self.get_app_access_token()
|
||||
tenant_access_token = self.get_tenant_access_token(tenant_key)
|
||||
req_opt: RequestOption = (
|
||||
RequestOption.builder()
|
||||
.app_ticket(self.app_ticket)
|
||||
.tenant_key(tenant_key)
|
||||
.app_access_token(app_access_token)
|
||||
.tenant_access_token(tenant_access_token)
|
||||
.build()
|
||||
)
|
||||
|
||||
card_sequence = self._next_card_sequence(card_id, msg_seq)
|
||||
|
||||
# ── RESUME: first chunk after button click ───────────────────────────
|
||||
if resume_from and card_id not in self.card_resume_transitioned:
|
||||
# Transition the card from the form state into resume mode.
|
||||
# Preserve the text that was shown before the pause, and seed the
|
||||
# resume placeholder with the current resume content if we already
|
||||
# have any on the first yielded chunk.
|
||||
pre_pause_text = self.card_pre_pause_text.get(card_id) or self.card_streaming_text.get(card_id, '')
|
||||
initial_resume_text = text_message or '\u200b'
|
||||
await self._update_card_layout(
|
||||
card_id=card_id,
|
||||
message_source=message_source,
|
||||
text_message=pre_pause_text,
|
||||
sequence=card_sequence,
|
||||
form_data=None,
|
||||
notice_text=selected_notice,
|
||||
resume_placeholder_text=initial_resume_text,
|
||||
)
|
||||
self.card_resume_transitioned.add(card_id)
|
||||
self.card_pre_pause_text[card_id] = pre_pause_text
|
||||
self.card_streaming_text[card_id] = text_message
|
||||
if not is_final:
|
||||
return
|
||||
|
||||
# ── RESUME: subsequent chunks → full card update ─────────────────────
|
||||
if resume_from and card_id in self.card_resume_transitioned:
|
||||
cached = self.card_streaming_text.get(card_id, '')
|
||||
if text_message != cached:
|
||||
self.card_streaming_text[card_id] = text_message
|
||||
pre_pause_text = self.card_pre_pause_text.get(card_id, '')
|
||||
await self._update_card_layout(
|
||||
card_id=card_id,
|
||||
message_source=message_source,
|
||||
text_message=pre_pause_text,
|
||||
sequence=card_sequence,
|
||||
form_data=None,
|
||||
notice_text=selected_notice,
|
||||
resume_placeholder_text=text_message,
|
||||
)
|
||||
if not is_final:
|
||||
return
|
||||
|
||||
# ── NORMAL streaming (non-resume): update streaming_txt in-place ──────
|
||||
if not resume_from and (msg_seq % 8 == 0 or is_final):
|
||||
cached = self.card_streaming_text.get(card_id)
|
||||
if text_message != cached:
|
||||
self.card_streaming_text[card_id] = text_message
|
||||
request: ContentCardElementRequest = (
|
||||
ContentCardElementRequest.builder()
|
||||
.card_id(card_id)
|
||||
.element_id('streaming_txt')
|
||||
.request_body(
|
||||
ContentCardElementRequestBody.builder().content(text_message).sequence(card_sequence).build()
|
||||
)
|
||||
request: ContentCardElementRequest = (
|
||||
ContentCardElementRequest.builder()
|
||||
.card_id(self.card_id_dict[message_id])
|
||||
.element_id('streaming_txt')
|
||||
.request_body(
|
||||
ContentCardElementRequestBody.builder()
|
||||
# .uuid("a0d69e20-1dd1-458b-k525-dfeca4015204")
|
||||
.content(text_message)
|
||||
.sequence(msg_seq)
|
||||
.build()
|
||||
)
|
||||
response: ContentCardElementResponse = await self.api_client.cardkit.v1.card_element.acontent(
|
||||
request, req_opt
|
||||
)
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f'client.cardkit.v1.card_element.acontent failed, code: {response.code}, '
|
||||
f'msg: {response.msg}, log_id: {response.get_log_id()}, '
|
||||
f'resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
# ── FINAL chunk: full card layout update ─────────────────────────────
|
||||
if is_final:
|
||||
final_seq = self._next_card_sequence(card_id, card_sequence + 1)
|
||||
pre_pause = self.card_pre_pause_text.get(card_id, text_message)
|
||||
resume_cached = self.card_streaming_text.get(card_id, '')
|
||||
if form_data:
|
||||
if open_new_card:
|
||||
# The old card has already been laid out into resume mode
|
||||
# by the resume-transition block above (notice + resume
|
||||
# placeholder). Finalise it as a frozen step snapshot and
|
||||
# spawn a brand-new card to host the next human-input
|
||||
# prompt — each step stays visible as its own card in the
|
||||
# chat history.
|
||||
new_card_id = await self._open_new_form_card(message_id, message_source, form_data)
|
||||
if new_card_id is None:
|
||||
# Fallback: keep the existing in-place behaviour so the
|
||||
# workflow remains continuable even if creating the
|
||||
# new card failed.
|
||||
await self._update_card_layout(
|
||||
card_id=card_id,
|
||||
message_source=message_source,
|
||||
text_message=pre_pause,
|
||||
sequence=final_seq,
|
||||
form_data=form_data,
|
||||
resume_placeholder_text=resume_cached,
|
||||
show_form_prompt=True,
|
||||
)
|
||||
self.card_streaming_text.pop(card_id, None)
|
||||
self.card_pre_pause_text.pop(card_id, None)
|
||||
else:
|
||||
# The old card is now a frozen snapshot; let go of its
|
||||
# streaming-side state but keep its source registrations
|
||||
# intact (no _drop_card_state) so historical button
|
||||
# callbacks aimed at it can still be matched if needed.
|
||||
self.card_streaming_text.pop(card_id, None)
|
||||
self.card_pre_pause_text.pop(card_id, None)
|
||||
self.card_resume_transitioned.discard(card_id)
|
||||
else:
|
||||
# Initial pause path: render prompt + buttons in place on
|
||||
# the current card.
|
||||
await self._update_card_layout(
|
||||
card_id=card_id,
|
||||
message_source=message_source,
|
||||
text_message=text_message,
|
||||
sequence=final_seq,
|
||||
form_data=form_data,
|
||||
show_form_prompt=True,
|
||||
)
|
||||
# The human-input prompt itself is rendered as buttons only
|
||||
# on Lark, so do not keep the hidden fallback text around;
|
||||
# otherwise it will resurface after the button click.
|
||||
self.card_streaming_text[card_id] = ''
|
||||
self.card_pre_pause_text[card_id] = ''
|
||||
else:
|
||||
# Normal finish: keep pre-pause + resume content visible,
|
||||
# remove buttons/notice, drop the resume placeholder.
|
||||
await self._update_card_layout(
|
||||
card_id=card_id,
|
||||
message_source=message_source,
|
||||
text_message=pre_pause,
|
||||
sequence=final_seq,
|
||||
form_data=None,
|
||||
notice_text=selected_notice if resume_from else '',
|
||||
resume_placeholder_text=resume_cached,
|
||||
)
|
||||
self._drop_card_state(card_id)
|
||||
self.card_id_dict.pop(message_id, None)
|
||||
|
||||
# ── media (images / files) appended at the end ───────────────────────
|
||||
if is_final and media_items:
|
||||
for media in media_items:
|
||||
media_request: ReplyMessageRequest = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(message_source.message_chain.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(media['content']))
|
||||
.msg_type(media['msg_type'])
|
||||
.reply_in_thread(False)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
media_response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(
|
||||
media_request, req_opt
|
||||
)
|
||||
if not media_response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message.reply ({media["msg_type"]}) failed, code: {media_response.code}, msg: {media_response.msg}, log_id: {media_response.get_log_id()}'
|
||||
)
|
||||
|
||||
async def _add_form_buttons_to_card(
|
||||
self,
|
||||
card_id: str,
|
||||
message_source: platform_events.MessageEvent,
|
||||
form_data: dict,
|
||||
text_message: str = '',
|
||||
sequence: int = 1,
|
||||
):
|
||||
"""Update the entire card to include form action buttons.
|
||||
|
||||
Uses card.aupdate to replace the card JSON with a template that
|
||||
includes the streaming text content plus interactive buttons.
|
||||
"""
|
||||
await self._update_card_layout(
|
||||
card_id=card_id,
|
||||
message_source=message_source,
|
||||
text_message=text_message,
|
||||
sequence=sequence,
|
||||
form_data=form_data,
|
||||
)
|
||||
|
||||
async def _remove_form_buttons_from_card(
|
||||
self,
|
||||
card_id: str,
|
||||
message_source: platform_events.MessageEvent,
|
||||
text_message: str = '',
|
||||
sequence: int = 1,
|
||||
):
|
||||
"""Replace the human-input card layout with the plain final layout."""
|
||||
await self._update_card_layout(
|
||||
card_id=card_id,
|
||||
message_source=message_source,
|
||||
text_message=text_message,
|
||||
sequence=sequence,
|
||||
form_data=None,
|
||||
)
|
||||
|
||||
async def _update_card_layout(
|
||||
self,
|
||||
card_id: str,
|
||||
message_source: platform_events.MessageEvent,
|
||||
text_message: str = '',
|
||||
sequence: int = 1,
|
||||
form_data: dict | None = None,
|
||||
notice_text: str = '',
|
||||
resume_placeholder_text: str = '',
|
||||
show_form_prompt: bool = True,
|
||||
):
|
||||
"""Update the entire card layout.
|
||||
|
||||
• form_data → show interactive buttons (initial Dify pause)
|
||||
• notice_text → replace buttons with a grey "已选择" notice (resume transition)
|
||||
• resume_placeholder_text → add a streaming_txt_resume markdown element
|
||||
"""
|
||||
form_data = form_data or {}
|
||||
actions = form_data.get('actions', [])
|
||||
form_token = form_data.get('form_token', '')
|
||||
workflow_run_id = form_data.get('workflow_run_id', '')
|
||||
node_title = form_data.get('node_title', '') or 'Human Input Required'
|
||||
form_content = form_data.get('form_content', '')
|
||||
|
||||
# When form_data is set, the visible content is rendered inside the
|
||||
# interactive container, so the top streaming text should stay empty
|
||||
# to avoid duplicate text above the action area.
|
||||
#
|
||||
# For resume notice state, keep the existing text visible in the card
|
||||
# and only add the grey "selected" notice below it.
|
||||
if form_data:
|
||||
render_text_message = ''
|
||||
else:
|
||||
render_text_message = text_message
|
||||
|
||||
# Determine session key from message source
|
||||
if isinstance(message_source, platform_events.GroupMessage):
|
||||
session_key = f'group_{message_source.group.id}'
|
||||
else:
|
||||
session_key = f'person_{message_source.sender.id}'
|
||||
|
||||
# Build button elements matching the existing card template's thumbsup/down format
|
||||
action_buttons = []
|
||||
for action in actions:
|
||||
action_id = action.get('id', '')
|
||||
action_title = action.get('title', action_id)
|
||||
button_style = action.get('button_style', 'default')
|
||||
|
||||
if button_style == 'primary':
|
||||
lark_button_type = 'primary'
|
||||
elif button_style == 'danger':
|
||||
lark_button_type = 'danger'
|
||||
else:
|
||||
lark_button_type = 'default'
|
||||
|
||||
action_buttons.append(
|
||||
{
|
||||
'tag': 'button',
|
||||
'text': {'tag': 'plain_text', 'content': action_title},
|
||||
'type': lark_button_type,
|
||||
'width': 'fill',
|
||||
'size': 'medium',
|
||||
'hover_tips': {'tag': 'plain_text', 'content': action_title},
|
||||
'behaviors': [
|
||||
{
|
||||
'type': 'callback',
|
||||
'value': {
|
||||
'form_action': True,
|
||||
'form_token': form_token,
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'action_id': action_id,
|
||||
'session_key': session_key,
|
||||
},
|
||||
}
|
||||
],
|
||||
'margin': '0px 0px 0px 0px',
|
||||
}
|
||||
.build()
|
||||
)
|
||||
|
||||
interactive_elements = []
|
||||
if form_data:
|
||||
if show_form_prompt:
|
||||
interactive_elements = [
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': f'**[Human Input Required] {node_title}**',
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 4px 0px',
|
||||
}
|
||||
]
|
||||
if form_content:
|
||||
interactive_elements.append(
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': form_content,
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 8px 0px',
|
||||
}
|
||||
)
|
||||
interactive_elements.append(
|
||||
{
|
||||
'tag': 'column_set',
|
||||
'horizontal_spacing': '8px',
|
||||
'horizontal_align': 'left',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
'columns': [
|
||||
{
|
||||
'tag': 'column',
|
||||
'width': 'weighted',
|
||||
'elements': [btn],
|
||||
'padding': '0px 0px 0px 0px',
|
||||
}
|
||||
for btn in action_buttons
|
||||
],
|
||||
}
|
||||
)
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
# self.seq = 1 # 消息回复结束之后重置seq
|
||||
self.card_id_dict.pop(message_id) # 清理已经使用过的卡片
|
||||
|
||||
# Build the full card JSON with buttons, same structure as create_card_id
|
||||
# ── mid_section: either form buttons, resume notice, or empty ──
|
||||
mid_section_elements = []
|
||||
if form_data:
|
||||
mid_section_elements = [
|
||||
{
|
||||
'tag': 'interactive_container',
|
||||
'margin': '12px 0px 8px 0px',
|
||||
'padding': '12px 12px 12px 12px',
|
||||
'has_border': True,
|
||||
'elements': interactive_elements,
|
||||
},
|
||||
{'tag': 'hr', 'margin': '0px 0px 0px 0px'},
|
||||
]
|
||||
elif notice_text:
|
||||
mid_section_elements = [
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': notice_text,
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '8px 0px 4px 0px',
|
||||
'text_color': 'grey',
|
||||
},
|
||||
{'tag': 'hr', 'margin': '0px 0px 0px 0px'},
|
||||
]
|
||||
|
||||
# ── resume placeholder element (empty, filled via acontent on each chunk) ──
|
||||
resume_elements = []
|
||||
if resume_placeholder_text:
|
||||
resume_elements = [
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': resume_placeholder_text,
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
'element_id': 'streaming_txt_resume',
|
||||
},
|
||||
]
|
||||
|
||||
card_data = {
|
||||
'schema': '2.0',
|
||||
'config': {
|
||||
'update_multi': True,
|
||||
'streaming_mode': False,
|
||||
},
|
||||
'body': {
|
||||
'direction': 'vertical',
|
||||
'padding': '12px 12px 12px 12px',
|
||||
'elements': [
|
||||
{
|
||||
'tag': 'div',
|
||||
'text': {
|
||||
'tag': 'plain_text',
|
||||
'content': 'LangBot',
|
||||
'text_size': 'normal',
|
||||
'text_align': 'left',
|
||||
'text_color': 'default',
|
||||
},
|
||||
'icon': {
|
||||
'tag': 'custom_icon',
|
||||
'img_key': 'img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg',
|
||||
},
|
||||
},
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': render_text_message,
|
||||
'text_align': 'left',
|
||||
'text_size': 'normal',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
'element_id': 'streaming_txt',
|
||||
},
|
||||
*mid_section_elements,
|
||||
*resume_elements,
|
||||
{
|
||||
'tag': 'column_set',
|
||||
'horizontal_spacing': '12px',
|
||||
'horizontal_align': 'right',
|
||||
'columns': [
|
||||
{
|
||||
'tag': 'column',
|
||||
'width': 'weighted',
|
||||
'elements': [
|
||||
{
|
||||
'tag': 'markdown',
|
||||
'content': '<font color="grey-600">以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看</font>',
|
||||
'text_align': 'left',
|
||||
'text_size': 'notation',
|
||||
'margin': '4px 0px 0px 0px',
|
||||
'icon': {
|
||||
'tag': 'standard_icon',
|
||||
'token': 'robot_outlined',
|
||||
'color': 'grey',
|
||||
},
|
||||
}
|
||||
],
|
||||
'padding': '0px 0px 0px 0px',
|
||||
'direction': 'vertical',
|
||||
'horizontal_spacing': '8px',
|
||||
'vertical_spacing': '8px',
|
||||
'horizontal_align': 'left',
|
||||
'vertical_align': 'top',
|
||||
'margin': '0px 0px 0px 0px',
|
||||
'weight': 1,
|
||||
}
|
||||
],
|
||||
'margin': '0px 0px 4px 0px',
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
tenant_key = (
|
||||
message_source.source_platform_object.header.tenant_key
|
||||
if message_source.source_platform_object
|
||||
@@ -2274,27 +1558,39 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
.tenant_access_token(tenant_access_token)
|
||||
.build()
|
||||
)
|
||||
# 发起请求
|
||||
response: ContentCardElementResponse = self.api_client.cardkit.v1.card_element.content(request, req_opt)
|
||||
|
||||
request: UpdateCardRequest = (
|
||||
UpdateCardRequest.builder()
|
||||
.card_id(card_id)
|
||||
.request_body(
|
||||
UpdateCardRequestBody.builder()
|
||||
.sequence(sequence)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.card(Card.builder().type('card_json').data(json.dumps(card_data)).build())
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response: UpdateCardResponse = await self.api_client.cardkit.v1.card.aupdate(request, req_opt)
|
||||
# 处理失败返回
|
||||
if not response.success():
|
||||
await self.logger.error(
|
||||
f'Failed to update lark card with form buttons: code={response.code}, msg={response.msg}, '
|
||||
f'log_id={response.get_log_id()}, resp={getattr(getattr(response, "raw", None), "content", None)}'
|
||||
raise Exception(
|
||||
f'client.im.v1.message.patch failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error updating lark card with form buttons: {traceback.format_exc()}')
|
||||
return
|
||||
|
||||
# Send media messages when streaming is done
|
||||
if is_final and media_items:
|
||||
for media in media_items:
|
||||
media_request: ReplyMessageRequest = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(message_source.message_chain.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(media['content']))
|
||||
.msg_type(media['msg_type'])
|
||||
.reply_in_thread(False)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
media_response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(
|
||||
media_request, req_opt
|
||||
)
|
||||
if not media_response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message.reply ({media["msg_type"]}) failed, code: {media_response.code}, msg: {media_response.msg}, log_id: {media_response.get_log_id()}'
|
||||
)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
|
||||
|
||||
import telegram
|
||||
import telegram.ext
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, CallbackQueryHandler, filters
|
||||
from telegram import Update
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, MessageHandler, filters
|
||||
import telegramify_markdown
|
||||
import typing
|
||||
import traceback
|
||||
import json
|
||||
import base64
|
||||
import pydantic
|
||||
|
||||
@@ -189,7 +189,6 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: telegram.Bot = pydantic.Field(exclude=True)
|
||||
application: telegram.ext.Application = pydantic.Field(exclude=True)
|
||||
ap: typing.Any = pydantic.Field(exclude=True, default=None)
|
||||
|
||||
message_converter: TelegramMessageConverter = TelegramMessageConverter()
|
||||
event_converter: TelegramEventConverter = TelegramEventConverter()
|
||||
@@ -225,102 +224,6 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
telegram_callback,
|
||||
)
|
||||
)
|
||||
|
||||
async def callback_query_handler(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
query = update.callback_query
|
||||
await query.answer()
|
||||
try:
|
||||
data = json.loads(query.data)
|
||||
if data.get('form_action') or data.get('f'):
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
|
||||
workflow_run_id = data.get('workflow_run_id', '')
|
||||
w_suffix = data.get('w', '')
|
||||
action_id = data.get('action_id') or data.get('a', '')
|
||||
session_key = data.get('session_key') or data.get('s', '')
|
||||
|
||||
if session_key.startswith('group_') or session_key.startswith('g:'):
|
||||
launcher_type = provider_session.LauncherTypes.GROUP
|
||||
launcher_id = (
|
||||
session_key.split(':', 1)[1]
|
||||
if session_key.startswith('g:')
|
||||
else session_key[len('group_') :]
|
||||
)
|
||||
else:
|
||||
launcher_type = provider_session.LauncherTypes.PERSON
|
||||
launcher_id = (
|
||||
session_key.split(':', 1)[1]
|
||||
if session_key.startswith('p:')
|
||||
else session_key[len('person_') :]
|
||||
)
|
||||
|
||||
user_id = str(query.from_user.id)
|
||||
|
||||
# Find bot_uuid and pipeline_uuid
|
||||
bot_uuid = ''
|
||||
pipeline_uuid = None
|
||||
for b in self.ap.platform_mgr.bots:
|
||||
if b.adapter is self:
|
||||
bot_uuid = b.bot_entity.uuid
|
||||
pipeline_uuid = b.bot_entity.use_pipeline_uuid
|
||||
break
|
||||
|
||||
form_action_data = {
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'w_suffix': w_suffix,
|
||||
'action_id': action_id,
|
||||
'user': f'{launcher_type.value}_{launcher_id}',
|
||||
'inputs': {},
|
||||
}
|
||||
|
||||
message_chain = platform_message.MessageChain(
|
||||
[platform_message.Plain(text=f'[Form Action: {action_id}]')]
|
||||
)
|
||||
|
||||
if launcher_type == provider_session.LauncherTypes.GROUP:
|
||||
synthetic_event = platform_events.GroupMessage(
|
||||
sender=platform_entities.GroupMember(
|
||||
id=user_id,
|
||||
member_name='',
|
||||
permission=platform_entities.Permission.Member,
|
||||
group=platform_entities.Group(
|
||||
id=launcher_id,
|
||||
name='',
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
),
|
||||
message_chain=message_chain,
|
||||
source_platform_object=update,
|
||||
)
|
||||
else:
|
||||
synthetic_event = platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=user_id,
|
||||
nickname='',
|
||||
remark='',
|
||||
),
|
||||
message_chain=message_chain,
|
||||
source_platform_object=update,
|
||||
)
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=bot_uuid,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=user_id,
|
||||
message_event=synthetic_event,
|
||||
message_chain=message_chain,
|
||||
adapter=self,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
variables={
|
||||
'_dify_form_action': form_action_data,
|
||||
'_routed_by_rule': True,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in telegram callback query: {traceback.format_exc()}')
|
||||
|
||||
application.add_handler(CallbackQueryHandler(callback_query_handler))
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
@@ -416,19 +319,14 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
update = event.source_platform_object
|
||||
chat_id = update.effective_chat.id
|
||||
chat_type = update.effective_chat.type
|
||||
effective_message = update.effective_message
|
||||
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
|
||||
message_thread_id = update.message.message_thread_id
|
||||
|
||||
if chat_type == 'private':
|
||||
import time as _time
|
||||
|
||||
draft_id = int(_time.time() * 1000)
|
||||
draft_id = int(time.time() * 1000)
|
||||
self.msg_stream_id[message_id] = ('private', draft_id)
|
||||
|
||||
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id, draft_id=draft_id)
|
||||
try:
|
||||
await self.bot.send_message_draft(**args)
|
||||
except (telegram.error.RetryAfter, telegram.error.BadRequest):
|
||||
pass
|
||||
await self.bot.send_message_draft(**args)
|
||||
else:
|
||||
args = self._build_message_args(chat_id, 'Thinking...', message_thread_id)
|
||||
send_msg = await self.bot.send_message(**args)
|
||||
@@ -449,13 +347,12 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
assert isinstance(message_source.source_platform_object, Update)
|
||||
update = message_source.source_platform_object
|
||||
chat_id = update.effective_chat.id
|
||||
effective_message = update.effective_message
|
||||
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
|
||||
message_thread_id = update.message.message_thread_id
|
||||
|
||||
if message_id not in self.msg_stream_id:
|
||||
return
|
||||
|
||||
chat_mode, stream_id = self.msg_stream_id[message_id]
|
||||
chat_mode, draft_id = self.msg_stream_id[message_id]
|
||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||
|
||||
if not components or components[0]['type'] != 'text':
|
||||
@@ -464,42 +361,16 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
return
|
||||
|
||||
content = components[0]['text']
|
||||
form_data = getattr(bot_message, '_form_data', None)
|
||||
|
||||
if form_data and is_final:
|
||||
self.msg_stream_id.pop(message_id, None)
|
||||
await self._send_form_action_buttons(message_source, form_data)
|
||||
return
|
||||
|
||||
if chat_mode == 'private':
|
||||
# Streaming via draft (ephemeral preview in the chat input area)
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=stream_id)
|
||||
try:
|
||||
await self.bot.send_message_draft(**args)
|
||||
except telegram.error.BadRequest as exc:
|
||||
if 'Message_too_long' in str(exc):
|
||||
args['text'] = content[:4000] + '\n\n… (truncated)'
|
||||
try:
|
||||
await self.bot.send_message_draft(**args)
|
||||
except telegram.error.RetryAfter:
|
||||
pass
|
||||
else:
|
||||
pass # Ignore other draft errors (cosmetic)
|
||||
args = self._build_message_args(chat_id, content, message_thread_id, draft_id=draft_id)
|
||||
await self.bot.send_message_draft(**args)
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
# Finalise: send the real message, discard the draft
|
||||
args = self._build_message_args(chat_id, content, message_thread_id)
|
||||
try:
|
||||
await self.bot.send_message(**args)
|
||||
except telegram.error.BadRequest as exc:
|
||||
if 'Message_too_long' in str(exc):
|
||||
args['text'] = content[:4000] + '\n\n… (truncated)'
|
||||
await self.bot.send_message(**args)
|
||||
else:
|
||||
raise
|
||||
del args['draft_id']
|
||||
await self.bot.send_message(**args)
|
||||
self.msg_stream_id.pop(message_id)
|
||||
else:
|
||||
# Streaming via edit_message_text (persistent message)
|
||||
stream_id = draft_id
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
args = {
|
||||
'message_id': stream_id,
|
||||
@@ -508,68 +379,11 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
}
|
||||
if self.config.get('markdown_card', False):
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
try:
|
||||
await self.bot.edit_message_text(**args)
|
||||
except telegram.error.BadRequest as exc:
|
||||
if 'Message_too_long' in str(exc):
|
||||
args['text'] = self._process_markdown(content[:4000] + '\n\n… (truncated)')
|
||||
await self.bot.edit_message_text(**args)
|
||||
else:
|
||||
raise
|
||||
await self.bot.edit_message_text(**args)
|
||||
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
self.msg_stream_id.pop(message_id)
|
||||
|
||||
async def _send_form_action_buttons(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
form_data: dict,
|
||||
):
|
||||
"""Send inline keyboard buttons for Dify human_input_required form actions."""
|
||||
actions = form_data.get('actions', [])
|
||||
node_title = form_data.get('node_title', '')
|
||||
form_content = form_data.get('form_content', '')
|
||||
workflow_run_id = form_data.get('workflow_run_id', '')
|
||||
# Telegram callback_data is capped at 64 bytes, so we identify the
|
||||
# paused workflow by the last 8 chars of workflow_run_id (unique
|
||||
# within a session with overwhelming probability).
|
||||
w_suffix = workflow_run_id[-8:] if workflow_run_id else ''
|
||||
|
||||
if isinstance(message_source, platform_events.GroupMessage):
|
||||
session_key = f'g:{message_source.group.id}'
|
||||
else:
|
||||
session_key = f'p:{message_source.sender.id}'
|
||||
|
||||
keyboard = []
|
||||
for action in actions:
|
||||
action_id = action.get('id', '')
|
||||
action_title = action.get('title', action_id)
|
||||
callback_payload = {'f': 1, 'a': action_id, 's': session_key}
|
||||
if w_suffix:
|
||||
callback_payload['w'] = w_suffix
|
||||
callback_data = json.dumps(callback_payload, separators=(',', ':'))
|
||||
keyboard.append([InlineKeyboardButton(action_title, callback_data=callback_data)])
|
||||
|
||||
reply_markup = InlineKeyboardMarkup(keyboard)
|
||||
|
||||
update = message_source.source_platform_object
|
||||
chat_id = update.effective_chat.id
|
||||
effective_message = update.effective_message
|
||||
message_thread_id = getattr(effective_message, 'message_thread_id', None) if effective_message else None
|
||||
|
||||
text_lines = [f'[{node_title}] Please select an action:']
|
||||
if form_content:
|
||||
text_lines.insert(0, form_content)
|
||||
args = {
|
||||
'chat_id': chat_id,
|
||||
'text': '\n\n'.join(text_lines),
|
||||
'reply_markup': reply_markup,
|
||||
}
|
||||
if message_thread_id:
|
||||
args['message_thread_id'] = message_thread_id
|
||||
|
||||
await self.bot.send_message(**args)
|
||||
|
||||
def get_launcher_id(self, event: platform_events.MessageEvent) -> str | None:
|
||||
if not isinstance(event.source_platform_object, Update):
|
||||
return None
|
||||
|
||||
@@ -2,11 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import base64
|
||||
import mimetypes
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
from langbot.pkg.provider import runner
|
||||
@@ -18,102 +16,6 @@ from langbot.libs.dify_service_api.v1 import client, errors
|
||||
import httpx
|
||||
|
||||
|
||||
# Module-level store for paused-workflow form state, keyed by session key
|
||||
# (launcher_type_value + "_" + launcher_id). Each session holds an
|
||||
# insertion-ordered dict of form_token -> form_data, allowing multiple
|
||||
# Dify workflows to be paused simultaneously for the same session.
|
||||
_PENDING_FORMS: dict[str, 'OrderedDict[str, dict[str, typing.Any]]'] = {}
|
||||
_PENDING_FORM_DEFAULT_TTL = 30 * 60 # 30 minutes safety cap
|
||||
|
||||
|
||||
def _session_key_from_query(query: pipeline_query.Query) -> str:
|
||||
return f'{query.session.launcher_type.value}_{query.session.launcher_id}'
|
||||
|
||||
|
||||
def _prune_pending_forms(now: float | None = None) -> None:
|
||||
if now is None:
|
||||
now = time.time()
|
||||
for session_key in list(_PENDING_FORMS.keys()):
|
||||
forms = _PENDING_FORMS[session_key]
|
||||
expired_tokens = [token for token, data in forms.items() if data.get('_expires_at', 0) <= now]
|
||||
for token in expired_tokens:
|
||||
forms.pop(token, None)
|
||||
if not forms:
|
||||
_PENDING_FORMS.pop(session_key, None)
|
||||
|
||||
|
||||
def _set_pending_form(session_key: str, form_data: dict[str, typing.Any]) -> None:
|
||||
_prune_pending_forms()
|
||||
stored = dict(form_data)
|
||||
expiration_time = stored.get('expiration_time')
|
||||
try:
|
||||
expiration_ts = float(expiration_time) if expiration_time is not None else 0.0
|
||||
except (TypeError, ValueError):
|
||||
expiration_ts = 0.0
|
||||
stored['_expires_at'] = expiration_ts or (time.time() + _PENDING_FORM_DEFAULT_TTL)
|
||||
form_token = str(stored.get('form_token') or '')
|
||||
forms = _PENDING_FORMS.setdefault(session_key, OrderedDict())
|
||||
# Re-insert at the end so this becomes the "latest" entry
|
||||
forms.pop(form_token, None)
|
||||
forms[form_token] = stored
|
||||
|
||||
|
||||
def _get_pending_form_by_token(session_key: str, form_token: str) -> dict[str, typing.Any] | None:
|
||||
_prune_pending_forms()
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms or not form_token:
|
||||
return None
|
||||
return forms.get(form_token)
|
||||
|
||||
|
||||
def _get_pending_form_by_w_suffix(session_key: str, w_suffix: str) -> dict[str, typing.Any] | None:
|
||||
"""Look up a pending form whose workflow_run_id ends with the given suffix.
|
||||
|
||||
Used by adapters (e.g. Telegram) whose callback payload is too small to
|
||||
carry the full form_token / workflow_run_id.
|
||||
"""
|
||||
_prune_pending_forms()
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms or not w_suffix:
|
||||
return None
|
||||
for token in reversed(forms):
|
||||
form = forms[token]
|
||||
if str(form.get('workflow_run_id', '')).endswith(w_suffix):
|
||||
return form
|
||||
return None
|
||||
|
||||
|
||||
def _get_latest_pending_form(session_key: str) -> dict[str, typing.Any] | None:
|
||||
_prune_pending_forms()
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms:
|
||||
return None
|
||||
return forms[next(reversed(forms))]
|
||||
|
||||
|
||||
def _iter_pending_forms(session_key: str) -> typing.Iterator[dict[str, typing.Any]]:
|
||||
"""Iterate pending forms for a session, newest-first."""
|
||||
_prune_pending_forms()
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms:
|
||||
return
|
||||
for token in reversed(list(forms.keys())):
|
||||
yield forms[token]
|
||||
|
||||
|
||||
def _clear_pending_form(session_key: str, form_token: str | None = None) -> None:
|
||||
"""Clear one specific pending form (by token) or all forms for the session."""
|
||||
forms = _PENDING_FORMS.get(session_key)
|
||||
if not forms:
|
||||
return
|
||||
if form_token is None:
|
||||
_PENDING_FORMS.pop(session_key, None)
|
||||
return
|
||||
forms.pop(form_token, None)
|
||||
if not forms:
|
||||
_PENDING_FORMS.pop(session_key, None)
|
||||
|
||||
|
||||
@runner.runner_class('dify-service-api')
|
||||
class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
"""Dify Service API 对话请求器"""
|
||||
@@ -433,140 +335,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _submit_workflow_form_blocking(
|
||||
self, form_action: dict
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""Submit human input to resume a paused Dify workflow (non-streaming)."""
|
||||
|
||||
form_token = form_action['form_token']
|
||||
workflow_run_id = form_action['workflow_run_id']
|
||||
user = form_action['user']
|
||||
action_id = form_action.get('action_id', '')
|
||||
inputs = form_action.get('inputs', {})
|
||||
|
||||
async for chunk in self.dify_client.workflow_submit(
|
||||
form_token=form_token,
|
||||
workflow_run_id=workflow_run_id,
|
||||
inputs=inputs,
|
||||
user=user,
|
||||
action=action_id,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-submit-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] == 'workflow_finished':
|
||||
if chunk['data'].get('error'):
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=content,
|
||||
)
|
||||
|
||||
def _resolve_pending_form(self, session_key: str, form_action: dict) -> dict | None:
|
||||
"""Locate the pending form this action targets.
|
||||
|
||||
Tries identifiers in order of specificity: form_token, full
|
||||
workflow_run_id, workflow_run_id suffix (Telegram-style compact id),
|
||||
then falls back to the newest pending form for the session.
|
||||
"""
|
||||
form_token = form_action.get('form_token')
|
||||
if form_token:
|
||||
form = _get_pending_form_by_token(session_key, form_token)
|
||||
if form:
|
||||
return form
|
||||
|
||||
workflow_run_id = form_action.get('workflow_run_id')
|
||||
if workflow_run_id:
|
||||
for form in _iter_pending_forms(session_key):
|
||||
if form.get('workflow_run_id') == workflow_run_id:
|
||||
return form
|
||||
|
||||
w_suffix = form_action.get('w_suffix')
|
||||
if w_suffix:
|
||||
form = _get_pending_form_by_w_suffix(session_key, w_suffix)
|
||||
if form:
|
||||
return form
|
||||
|
||||
return _get_latest_pending_form(session_key)
|
||||
|
||||
def _merge_pending_form_action(self, session_key: str, form_action: dict | None) -> dict | None:
|
||||
"""Backfill resume fields from the matching pending form."""
|
||||
if not form_action:
|
||||
return None
|
||||
|
||||
merged_action = dict(form_action)
|
||||
merged_action.pop('w_suffix', None)
|
||||
pending_form = self._resolve_pending_form(session_key, form_action)
|
||||
if pending_form:
|
||||
merged_action['form_token'] = merged_action.get('form_token') or pending_form.get('form_token', '')
|
||||
merged_action['workflow_run_id'] = merged_action.get('workflow_run_id') or pending_form.get(
|
||||
'workflow_run_id', ''
|
||||
)
|
||||
merged_action.setdefault('inputs', pending_form.get('inputs', {}))
|
||||
merged_action.setdefault('user', pending_form.get('user', ''))
|
||||
merged_action.setdefault('node_title', pending_form.get('node_title', ''))
|
||||
|
||||
# Resolve clicked action's display title from the stored actions list
|
||||
if 'action_title' not in merged_action:
|
||||
clicked_id = merged_action.get('action_id', '')
|
||||
for action in pending_form.get('actions', []):
|
||||
if str(action.get('id', '')) == str(clicked_id):
|
||||
merged_action['action_title'] = action.get('title', clicked_id)
|
||||
break
|
||||
|
||||
return merged_action
|
||||
|
||||
def _match_pending_form_action(self, session_key: str, user_text: str) -> dict | None:
|
||||
"""Match plain text replies against pending Dify form actions.
|
||||
|
||||
Iterates all pending forms newest-first; the first action whose
|
||||
title/id matches the text wins. This means when multiple forms are
|
||||
pending with the same button label, the most recent one resolves.
|
||||
"""
|
||||
normalized_text = user_text.strip().lower()
|
||||
if not normalized_text:
|
||||
return None
|
||||
|
||||
for pending_form in _iter_pending_forms(session_key):
|
||||
for action in pending_form.get('actions', []):
|
||||
titles = {
|
||||
str(action.get('title', '')).strip().lower(),
|
||||
str(action.get('id', '')).strip().lower(),
|
||||
}
|
||||
if normalized_text in titles:
|
||||
return {
|
||||
'form_token': pending_form.get('form_token', ''),
|
||||
'workflow_run_id': pending_form.get('workflow_run_id', ''),
|
||||
'action_id': action.get('id', ''),
|
||||
'action_title': action.get('title', action.get('id', '')),
|
||||
'node_title': pending_form.get('node_title', ''),
|
||||
'inputs': pending_form.get('inputs', {}),
|
||||
'user': pending_form.get('user', ''),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
async def _workflow_messages(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""调用工作流"""
|
||||
|
||||
# Check if this is a form action resume (button click or text match)
|
||||
form_action_raw = query.variables.get('_dify_form_action')
|
||||
session_key = _session_key_from_query(query)
|
||||
|
||||
if form_action_raw:
|
||||
form_action = self._merge_pending_form_action(session_key, form_action_raw)
|
||||
else:
|
||||
form_action = self._match_pending_form_action(session_key, str(query.message_chain))
|
||||
|
||||
if form_action:
|
||||
_clear_pending_form(session_key, form_action.get('form_token') or None)
|
||||
async for msg in self._submit_workflow_form_blocking(form_action):
|
||||
yield msg
|
||||
return
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
@@ -593,7 +366,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
}
|
||||
|
||||
inputs.update(query.variables)
|
||||
human_input_yielded = False
|
||||
|
||||
async for chunk in self.dify_client.workflow_run(
|
||||
inputs=inputs,
|
||||
@@ -605,46 +377,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'workflow_paused':
|
||||
reasons = chunk['data'].get('reasons', [])
|
||||
workflow_run_id = chunk['data'].get('workflow_run_id', '')
|
||||
for reason in reasons:
|
||||
if reason.get('TYPE') == 'human_input_required':
|
||||
form_content = reason.get('form_content', '')
|
||||
actions = reason.get('actions', [])
|
||||
node_title = reason.get('node_title', '')
|
||||
|
||||
_set_pending_form(
|
||||
_session_key_from_query(query),
|
||||
{
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'form_id': reason.get('form_id'),
|
||||
'form_token': reason.get('form_token'),
|
||||
'node_id': reason.get('node_id'),
|
||||
'node_title': node_title,
|
||||
'form_content': form_content,
|
||||
'inputs': reason.get('inputs', {}),
|
||||
'actions': actions,
|
||||
'expiration_time': reason.get('expiration_time'),
|
||||
'user': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
},
|
||||
)
|
||||
|
||||
query.variables['_dify_form_render'] = {
|
||||
'form_content': form_content,
|
||||
'actions': actions,
|
||||
'node_title': node_title,
|
||||
}
|
||||
|
||||
action_lines = '\n'.join(f'- [{a.get("title", a.get("id", ""))}]' for a in actions)
|
||||
display_text = f'[Human Input Required] {node_title}\n{form_content}\n{action_lines}'
|
||||
|
||||
human_input_yielded = True
|
||||
yield provider_message.Message(
|
||||
role='assistant',
|
||||
content=display_text,
|
||||
)
|
||||
|
||||
if chunk['event'] == 'node_started':
|
||||
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
|
||||
continue
|
||||
@@ -667,8 +399,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
yield msg
|
||||
|
||||
elif chunk['event'] == 'workflow_finished':
|
||||
if human_input_yielded:
|
||||
break
|
||||
if chunk['data']['error']:
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
|
||||
@@ -906,153 +636,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _submit_workflow_form(
|
||||
self, form_action: dict
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""Submit human input to resume a paused Dify workflow."""
|
||||
|
||||
form_token = form_action['form_token']
|
||||
workflow_run_id = form_action['workflow_run_id']
|
||||
user = form_action['user']
|
||||
action_id = form_action.get('action_id', '')
|
||||
action_title = form_action.get('action_title', '') or action_id
|
||||
node_title = form_action.get('node_title', '')
|
||||
inputs = form_action.get('inputs', {})
|
||||
|
||||
messsage_idx = 0
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
workflow_contents = ''
|
||||
repause_form_data: dict | None = None
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')
|
||||
async for chunk in self.dify_client.workflow_submit(
|
||||
form_token=form_token,
|
||||
workflow_run_id=workflow_run_id,
|
||||
inputs=inputs,
|
||||
user=user,
|
||||
action=action_id,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-submit-chunk: ' + str(chunk))
|
||||
|
||||
yield_this_iteration = False
|
||||
|
||||
if chunk['event'] == 'workflow_finished':
|
||||
is_final = True
|
||||
yield_this_iteration = True
|
||||
if chunk['data'].get('error'):
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
|
||||
if chunk['event'] == 'workflow_paused':
|
||||
reasons = chunk['data'].get('reasons', [])
|
||||
new_run_id = chunk['data'].get('workflow_run_id', workflow_run_id)
|
||||
for reason in reasons:
|
||||
if reason.get('TYPE') != 'human_input_required':
|
||||
continue
|
||||
form_content = reason.get('form_content', '')
|
||||
actions = reason.get('actions', [])
|
||||
# Use a distinct name — `node_title` (the just-resolved step)
|
||||
# must keep its value so the resume notice on the previous
|
||||
# card still shows which step the user acted on.
|
||||
paused_node_title = reason.get('node_title', '')
|
||||
raw_inputs = reason.get('inputs', {})
|
||||
|
||||
_set_pending_form(
|
||||
user,
|
||||
{
|
||||
'workflow_run_id': new_run_id,
|
||||
'form_id': reason.get('form_id'),
|
||||
'form_token': reason.get('form_token'),
|
||||
'node_id': reason.get('node_id'),
|
||||
'node_title': paused_node_title,
|
||||
'form_content': form_content,
|
||||
'inputs': raw_inputs if isinstance(raw_inputs, dict) else {},
|
||||
'actions': actions,
|
||||
'expiration_time': reason.get('expiration_time'),
|
||||
'user': user,
|
||||
},
|
||||
)
|
||||
|
||||
repause_form_data = {
|
||||
'form_content': form_content,
|
||||
'actions': actions,
|
||||
'node_title': paused_node_title,
|
||||
'workflow_run_id': new_run_id,
|
||||
'form_token': reason.get('form_token', ''),
|
||||
}
|
||||
# Ensure the final chunk has non-empty content so
|
||||
# ResponseWrapper (which skips empty-content chunks) lets it
|
||||
# propagate to SendResponseBackStage. Use a zero-width space
|
||||
# so neither Lark nor Telegram renders visible noise — the
|
||||
# adapter substitutes its own card text from _form_data.
|
||||
if not workflow_contents:
|
||||
workflow_contents = ''
|
||||
is_final = True
|
||||
yield_this_iteration = True
|
||||
break
|
||||
|
||||
if chunk['event'] == 'text_chunk':
|
||||
messsage_idx += 1
|
||||
if remove_think:
|
||||
if '<think>' in chunk['data']['text'] and not think_start:
|
||||
think_start = True
|
||||
continue
|
||||
if '</think>' in chunk['data']['text'] and not think_end:
|
||||
import re
|
||||
|
||||
content = re.sub(r'^\n</think>', '', chunk['data']['text'])
|
||||
workflow_contents += content
|
||||
think_end = True
|
||||
elif think_end:
|
||||
workflow_contents += chunk['data']['text']
|
||||
if think_start:
|
||||
continue
|
||||
else:
|
||||
workflow_contents += chunk['data']['text']
|
||||
if messsage_idx % 8 == 0:
|
||||
yield_this_iteration = True
|
||||
|
||||
if yield_this_iteration:
|
||||
msg = provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=workflow_contents,
|
||||
is_final=is_final,
|
||||
)
|
||||
msg._resume_from_form = True
|
||||
if action_title:
|
||||
msg._resume_action_title = action_title
|
||||
if node_title:
|
||||
msg._resume_node_title = node_title
|
||||
if is_final and repause_form_data:
|
||||
msg._form_data = repause_form_data
|
||||
msg._open_new_card = True
|
||||
yield msg
|
||||
if is_final:
|
||||
return
|
||||
|
||||
async def _workflow_messages_chunk(
|
||||
self, query: pipeline_query.Query
|
||||
) -> typing.AsyncGenerator[provider_message.MessageChunk, None]:
|
||||
"""调用工作流"""
|
||||
|
||||
# Check if this is a form action resume (button click or text match)
|
||||
form_action_raw = query.variables.get('_dify_form_action')
|
||||
session_key = _session_key_from_query(query)
|
||||
|
||||
if form_action_raw:
|
||||
form_action = self._merge_pending_form_action(session_key, form_action_raw)
|
||||
else:
|
||||
form_action = self._match_pending_form_action(session_key, str(query.message_chain))
|
||||
|
||||
if form_action:
|
||||
_clear_pending_form(session_key, form_action.get('form_token') or None)
|
||||
# Resume paused workflow via submit endpoint
|
||||
async for msg in self._submit_workflow_form(form_action):
|
||||
yield msg
|
||||
return
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
@@ -1084,13 +672,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
think_start = False
|
||||
think_end = False
|
||||
workflow_contents = ''
|
||||
workflow_run_id = ''
|
||||
human_input_yielded = False
|
||||
|
||||
# Saved form data to attach to the final MessageChunk so the adapter
|
||||
# can detect it when is_final=True and render buttons.
|
||||
pending_form_data = None
|
||||
display_text = ''
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
async for chunk in self.dify_client.workflow_run(
|
||||
@@ -1101,62 +682,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
|
||||
if chunk['event'] in ignored_events:
|
||||
if chunk['event'] == 'workflow_started':
|
||||
workflow_run_id = chunk['data'].get('workflow_run_id', '')
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'workflow_paused':
|
||||
reasons = chunk['data'].get('reasons', [])
|
||||
workflow_run_id = chunk['data'].get('workflow_run_id', workflow_run_id)
|
||||
for reason in reasons:
|
||||
if reason.get('TYPE') == 'human_input_required':
|
||||
form_content = reason.get('form_content', '')
|
||||
actions = reason.get('actions', [])
|
||||
node_title = reason.get('node_title', '')
|
||||
|
||||
# Persist form state in module-level store keyed by session
|
||||
raw_inputs = reason.get('inputs', {})
|
||||
_set_pending_form(
|
||||
_session_key_from_query(query),
|
||||
{
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'form_id': reason.get('form_id'),
|
||||
'form_token': reason.get('form_token'),
|
||||
'node_id': reason.get('node_id'),
|
||||
'node_title': node_title,
|
||||
'form_content': form_content,
|
||||
'inputs': raw_inputs if isinstance(raw_inputs, dict) else {},
|
||||
'actions': actions,
|
||||
'expiration_time': reason.get('expiration_time'),
|
||||
'user': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
},
|
||||
)
|
||||
|
||||
# Pass form render metadata to downstream stages
|
||||
query.variables['_dify_form_render'] = {
|
||||
'form_content': form_content,
|
||||
'actions': actions,
|
||||
'node_title': node_title,
|
||||
}
|
||||
|
||||
action_lines = '\n'.join(f'- [{a.get("title", a.get("id", ""))}]' for a in actions)
|
||||
display_text = f'[Human Input Required] {node_title}\n{form_content}\n{action_lines}'
|
||||
workflow_contents += display_text + '\n'
|
||||
|
||||
# Save form data to attach to the final chunk later.
|
||||
# We do NOT yield here — the form content will be sent
|
||||
# as the final MessageChunk (with is_final=True and
|
||||
# _form_data) so the adapter can update the card and
|
||||
# add buttons in one pass.
|
||||
pending_form_data = {
|
||||
'form_content': form_content,
|
||||
'actions': actions,
|
||||
'node_title': node_title,
|
||||
'workflow_run_id': workflow_run_id,
|
||||
'form_token': reason.get('form_token', ''),
|
||||
}
|
||||
human_input_yielded = True
|
||||
|
||||
if chunk['event'] == 'workflow_finished':
|
||||
is_final = True
|
||||
if chunk['data']['error']:
|
||||
@@ -1204,29 +730,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
yield msg
|
||||
|
||||
if messsage_idx % 8 == 0 or is_final:
|
||||
final_content = workflow_contents if workflow_contents.strip() else ''
|
||||
msg = provider_message.MessageChunk(
|
||||
yield provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=final_content,
|
||||
content=workflow_contents,
|
||||
is_final=is_final,
|
||||
)
|
||||
# Attach form data to the final chunk for the adapter
|
||||
if is_final and pending_form_data:
|
||||
msg._form_data = pending_form_data
|
||||
pending_form_data = None
|
||||
yield msg
|
||||
|
||||
# If the stream ended after workflow_paused without a
|
||||
# workflow_finished event, yield a final chunk so the adapter
|
||||
# can update the card and add buttons.
|
||||
if human_input_yielded and not is_final:
|
||||
msg = provider_message.MessageChunk(
|
||||
role='assistant',
|
||||
content=workflow_contents or display_text,
|
||||
is_final=True,
|
||||
)
|
||||
msg._form_data = pending_form_data
|
||||
yield msg
|
||||
|
||||
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[provider_message.Message, None]:
|
||||
"""运行请求"""
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def dedupe_preregistered_groups() -> None:
|
||||
"""Keep API integration route registration isolated across test modules."""
|
||||
from langbot.pkg.api.http.controller import group
|
||||
|
||||
seen: set[tuple[str, str]] = set()
|
||||
unique_groups = []
|
||||
for group_cls in group.preregistered_groups:
|
||||
key = (group_cls.name, group_cls.path)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique_groups.append(group_cls)
|
||||
|
||||
group.preregistered_groups[:] = unique_groups
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def http_controller_cls(mock_circular_import_chain):
|
||||
"""Import HTTPController under each module's circular-import isolation."""
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
dedupe_preregistered_groups()
|
||||
return HTTPController
|
||||
@@ -102,9 +102,11 @@ def fake_bot_app():
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_bot_app, http_controller_cls):
|
||||
async def quart_test_client(fake_bot_app):
|
||||
"""Create Quart test client (module scope to avoid route re-registration)."""
|
||||
controller = http_controller_cls(fake_bot_app)
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
controller = HTTPController(fake_bot_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
|
||||
@@ -101,9 +101,11 @@ def fake_embed_app():
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_embed_app, http_controller_cls):
|
||||
async def quart_test_client(fake_embed_app):
|
||||
"""Create Quart test client (module scope)."""
|
||||
controller = http_controller_cls(fake_embed_app)
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
controller = HTTPController(fake_embed_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
@@ -297,4 +299,4 @@ class TestEmbedFeedbackEndpoint:
|
||||
json={'message_id': 'msg-123', 'feedback_type': 99}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 400
|
||||
@@ -107,9 +107,11 @@ def fake_knowledge_app():
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_knowledge_app, http_controller_cls):
|
||||
async def quart_test_client(fake_knowledge_app):
|
||||
"""Create Quart test client (module scope to avoid route re-registration)."""
|
||||
controller = http_controller_cls(fake_knowledge_app)
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
controller = HTTPController(fake_knowledge_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
|
||||
@@ -113,9 +113,11 @@ def fake_monitoring_app():
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_monitoring_app, http_controller_cls):
|
||||
async def quart_test_client(fake_monitoring_app):
|
||||
"""Create Quart test client (module scope)."""
|
||||
controller = http_controller_cls(fake_monitoring_app)
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
controller = HTTPController(fake_monitoring_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
@@ -327,4 +329,4 @@ class TestMonitoringExportEndpoint:
|
||||
headers={'Authorization': 'Bearer test_token'}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == 200
|
||||
@@ -119,9 +119,11 @@ def fake_pipeline_app():
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_pipeline_app, http_controller_cls):
|
||||
async def quart_test_client(fake_pipeline_app):
|
||||
"""Create Quart test client (module scope to avoid route re-registration)."""
|
||||
controller = http_controller_cls(fake_pipeline_app)
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
controller = HTTPController(fake_pipeline_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
|
||||
@@ -116,9 +116,11 @@ def fake_provider_app():
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def quart_test_client(fake_provider_app, http_controller_cls):
|
||||
async def quart_test_client(fake_provider_app):
|
||||
"""Create Quart test client (module scope to avoid route re-registration)."""
|
||||
controller = http_controller_cls(fake_provider_app)
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
controller = HTTPController(fake_provider_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
|
||||
@@ -119,13 +119,15 @@ def fake_api_app():
|
||||
# ============== QUART TEST CLIENT FIXTURE ==============
|
||||
|
||||
@pytest.fixture
|
||||
async def quart_test_client(fake_api_app, http_controller_cls):
|
||||
async def quart_test_client(fake_api_app):
|
||||
"""
|
||||
Create Quart test client with real HTTPController and route registration.
|
||||
|
||||
Requires mock_circular_import_chain fixture to run first (usefixtures).
|
||||
"""
|
||||
controller = http_controller_cls(fake_api_app)
|
||||
from langbot.pkg.api.http.controller.main import HTTPController
|
||||
|
||||
controller = HTTPController(fake_api_app)
|
||||
await controller.initialize()
|
||||
|
||||
client = controller.quart_app.test_client()
|
||||
@@ -342,4 +344,4 @@ class TestRealImports:
|
||||
break
|
||||
|
||||
assert user_group is not None
|
||||
assert user_group.path == '/api/v1/user'
|
||||
assert user_group.path == '/api/v1/user'
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
PoC test for CWE-94: Authenticated RCE via exec() on user-supplied Python code.
|
||||
|
||||
The /api/v1/system/debug/exec endpoint passes raw HTTP body to exec(),
|
||||
allowing arbitrary code execution when debug_mode is True.
|
||||
|
||||
This test verifies that:
|
||||
1. The exec() endpoint is removed from the codebase entirely.
|
||||
2. No route matches /api/v1/system/debug/exec.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import pathlib
|
||||
|
||||
# Resolve project root (one level up from tests/)
|
||||
_PROJECT_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
|
||||
VULN_FILE = (
|
||||
_PROJECT_ROOT
|
||||
/ "src"
|
||||
/ "langbot"
|
||||
/ "pkg"
|
||||
/ "api"
|
||||
/ "http"
|
||||
/ "controller"
|
||||
/ "groups"
|
||||
/ "system.py"
|
||||
)
|
||||
|
||||
|
||||
def test_no_exec_call_in_system_controller():
|
||||
"""Verify there is no exec() call in system.py that takes user input."""
|
||||
with open(VULN_FILE, "r") as f:
|
||||
source = f.read()
|
||||
|
||||
tree = ast.parse(source)
|
||||
|
||||
exec_calls = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Call):
|
||||
func = node.func
|
||||
# Match bare exec() call
|
||||
if isinstance(func, ast.Name) and func.id == "exec":
|
||||
exec_calls.append(node.lineno)
|
||||
|
||||
assert len(exec_calls) == 0, (
|
||||
f"Found exec() call(s) at line(s) {exec_calls} in system.py. "
|
||||
"User-supplied code must never be passed to exec()."
|
||||
)
|
||||
|
||||
|
||||
def test_no_debug_exec_route():
|
||||
"""Verify the /debug/exec route is not registered."""
|
||||
with open(VULN_FILE, "r") as f:
|
||||
source = f.read()
|
||||
|
||||
assert "debug/exec" not in source, (
|
||||
"The /debug/exec route still exists in system.py. "
|
||||
"This endpoint allows arbitrary code execution and must be removed."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_no_exec_call_in_system_controller()
|
||||
test_no_debug_exec_route()
|
||||
print("All tests passed!")
|
||||
@@ -479,47 +479,6 @@ class TestMessageAggregatorMerge:
|
||||
assert "hello" in merged_str
|
||||
assert "world" in merged_str
|
||||
|
||||
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
|
||||
"""Merged PendingMessage should keep routed_by_rule when any input was rule-routed."""
|
||||
aggregator = get_aggregator_module()
|
||||
|
||||
app = make_aggregator_app()
|
||||
agg = aggregator.MessageAggregator(app)
|
||||
|
||||
chain1 = text_chain("first")
|
||||
chain2 = text_chain("second")
|
||||
event = friend_message_event(chain1)
|
||||
adapter = mock_adapter()
|
||||
|
||||
pending1 = aggregator.PendingMessage(
|
||||
bot_uuid='test-bot',
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_event=event,
|
||||
message_chain=chain1,
|
||||
adapter=adapter,
|
||||
pipeline_uuid='test-pipeline-uuid',
|
||||
routed_by_rule=False,
|
||||
)
|
||||
|
||||
pending2 = aggregator.PendingMessage(
|
||||
bot_uuid='test-bot',
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_event=event,
|
||||
message_chain=chain2,
|
||||
adapter=adapter,
|
||||
pipeline_uuid='test-pipeline-uuid',
|
||||
routed_by_rule=True,
|
||||
)
|
||||
|
||||
merged = agg._merge_messages([pending1, pending2])
|
||||
|
||||
assert merged.routed_by_rule is True
|
||||
assert str(merged.message_chain) == 'first\nsecond'
|
||||
|
||||
|
||||
class TestMessageAggregatorFlush:
|
||||
"""Tests for buffer flush behavior."""
|
||||
@@ -635,3 +594,44 @@ class TestMessageAggregatorFlushAll:
|
||||
# Both buffers should be flushed
|
||||
assert len(agg.buffers) == 0
|
||||
assert app.query_pool.add_query.call_count == 2
|
||||
|
||||
|
||||
class TestMessageAggregatorMergeRoutedFlag:
|
||||
"""Tests for preserving routed message state during merge."""
|
||||
|
||||
def test_merge_messages_preserves_routed_by_rule_if_any_input_matches(self):
|
||||
"""Merged PendingMessage keeps routed_by_rule when any input was rule-routed."""
|
||||
aggregator = get_aggregator_module()
|
||||
agg = aggregator.MessageAggregator(ap=None)
|
||||
chain1 = text_chain("first")
|
||||
chain2 = text_chain("second")
|
||||
event = friend_message_event(chain1)
|
||||
adapter = mock_adapter()
|
||||
|
||||
pending1 = aggregator.PendingMessage(
|
||||
bot_uuid='test-bot',
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_event=event,
|
||||
message_chain=chain1,
|
||||
adapter=adapter,
|
||||
pipeline_uuid='test-pipeline',
|
||||
routed_by_rule=False,
|
||||
)
|
||||
pending2 = aggregator.PendingMessage(
|
||||
bot_uuid='test-bot',
|
||||
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||
launcher_id=12345,
|
||||
sender_id=12345,
|
||||
message_event=event,
|
||||
message_chain=chain2,
|
||||
adapter=adapter,
|
||||
pipeline_uuid='test-pipeline',
|
||||
routed_by_rule=True,
|
||||
)
|
||||
|
||||
merged = agg._merge_messages([pending1, pending2])
|
||||
|
||||
assert merged.routed_by_rule is True
|
||||
assert str(merged.message_chain) == 'first\nsecond'
|
||||
|
||||
@@ -119,10 +119,10 @@ class TestContentFilterStageInit:
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
assert {filter_impl.name for filter_impl in stage.filter_chain} == {
|
||||
assert [filter_impl.name for filter_impl in stage.filter_chain] == [
|
||||
'ban-word-filter',
|
||||
'content-ignore',
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class TestPreContentFilter:
|
||||
|
||||
@@ -11,7 +11,7 @@ Tests cover:
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from unittest.mock import Mock
|
||||
from importlib import import_module
|
||||
|
||||
from tests.factories import (
|
||||
@@ -166,6 +166,29 @@ class TestLongTextProcessStageProcess:
|
||||
assert isinstance(components[0], platform_message.Plain)
|
||||
assert components[0].text == 'short response'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_message_chain_continues_without_processing(self):
|
||||
"""Empty response chains should be a no-op for long text processing."""
|
||||
longtext = get_longtext_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = longtext.LongTextProcessStage(app)
|
||||
|
||||
pipeline_config = make_longtext_config(strategy='forward', threshold=1)
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
result = await stage.process(query, 'LongTextProcessStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query is query
|
||||
assert query.resp_message_chain == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_plain_component_skips(self):
|
||||
"""resp_message_chain with non-Plain components should skip processing."""
|
||||
@@ -200,48 +223,6 @@ class TestLongTextProcessStageProcess:
|
||||
assert components[0].text == 'short'
|
||||
assert components[1].url == 'https://example.com/img.png'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_resp_message_chain(self):
|
||||
"""Empty resp_message_chain should be handled gracefully."""
|
||||
longtext = get_longtext_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = longtext.LongTextProcessStage(app)
|
||||
|
||||
pipeline_config = make_longtext_config(strategy='forward')
|
||||
|
||||
await stage.initialize(pipeline_config)
|
||||
|
||||
query = text_query("hello")
|
||||
query.pipeline_config = pipeline_config
|
||||
query.resp_message_chain = []
|
||||
|
||||
result = await stage.process(query, 'LongTextProcessStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query is query
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_message_chain_does_not_call_strategy(self):
|
||||
"""Empty response chains should be a no-op for long text processing."""
|
||||
longtext = get_longtext_module()
|
||||
entities = get_entities_module()
|
||||
|
||||
app = FakeApp()
|
||||
stage = longtext.LongTextProcessStage(app)
|
||||
stage.strategy_impl = AsyncMock()
|
||||
|
||||
query = text_query("hello")
|
||||
query.pipeline_config = make_longtext_config(strategy='forward', threshold=1)
|
||||
query.resp_message_chain = []
|
||||
|
||||
result = await stage.process(query, 'LongTextProcessStage')
|
||||
|
||||
assert result.result_type == entities.ResultType.CONTINUE
|
||||
assert result.new_query is query
|
||||
stage.strategy_impl.process.assert_not_called()
|
||||
|
||||
class TestForwardStrategy:
|
||||
"""Tests for ForwardComponentStrategy."""
|
||||
|
||||
|
||||
@@ -223,8 +223,10 @@ def test_token_manager_next_token_empty():
|
||||
"""Test TokenManager.next_token with empty tokens doesn't error."""
|
||||
mgr = token.TokenManager(name='test', tokens=[])
|
||||
|
||||
assert mgr.next_token() is None
|
||||
mgr.next_token()
|
||||
|
||||
assert mgr.get_token() == ''
|
||||
assert mgr.using_token_index == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -434,23 +434,6 @@ class TestRAGRuntimeServiceGetFileStream:
|
||||
|
||||
assert result == b''
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_file_stream_normalizes_safe_path(self):
|
||||
"""Safe relative paths are normalized before loading."""
|
||||
mock_app = self._create_mock_app()
|
||||
|
||||
mocks = self._make_rag_import_mocks()
|
||||
|
||||
with isolated_sys_modules(mocks):
|
||||
from langbot.pkg.rag.service.runtime import RAGRuntimeService
|
||||
|
||||
service = RAGRuntimeService(mock_app)
|
||||
|
||||
result = await service.get_file_stream('knowledge/./files/doc.pdf')
|
||||
|
||||
assert result == b'file content'
|
||||
mock_app.storage_mgr.storage_provider.load.assert_called_once_with('knowledge/files/doc.pdf')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_file_stream_path_traversal_blocked(self):
|
||||
"""Path traversal attacks are blocked."""
|
||||
@@ -471,37 +454,6 @@ class TestRAGRuntimeServiceGetFileStream:
|
||||
with pytest.raises(ValueError, match='Invalid storage path'):
|
||||
await service.get_file_stream('knowledge/../../../etc/passwd')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'storage_path',
|
||||
[
|
||||
'',
|
||||
'../secret.txt',
|
||||
'/absolute/path.txt',
|
||||
'..\\secret.txt',
|
||||
'nested\\..\\secret.txt',
|
||||
'%2e%2e/secret.txt',
|
||||
'nested/%2e%2e/secret.txt',
|
||||
'C:\\secret.txt',
|
||||
'safe/\x00file.txt',
|
||||
],
|
||||
)
|
||||
async def test_get_file_stream_rejects_unsafe_paths(self, storage_path: str):
|
||||
"""Unsafe runtime file paths are rejected before storage load."""
|
||||
mock_app = self._create_mock_app()
|
||||
|
||||
mocks = self._make_rag_import_mocks()
|
||||
|
||||
with isolated_sys_modules(mocks):
|
||||
from langbot.pkg.rag.service.runtime import RAGRuntimeService
|
||||
|
||||
service = RAGRuntimeService(mock_app)
|
||||
|
||||
with pytest.raises(ValueError, match='Invalid storage path'):
|
||||
await service.get_file_stream(storage_path)
|
||||
|
||||
mock_app.storage_mgr.storage_provider.load.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_file_stream_normalizes_path(self):
|
||||
"""Valid paths with .. in filename (not traversal) should work."""
|
||||
@@ -520,3 +472,50 @@ class TestRAGRuntimeServiceGetFileStream:
|
||||
# Let's test a simple valid path
|
||||
await service.get_file_stream('knowledge/files/test.pdf')
|
||||
mock_app.storage_mgr.storage_provider.load.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_file_stream_normalizes_safe_relative_path(self):
|
||||
"""Safe relative paths are normalized before loading."""
|
||||
mock_app = self._create_mock_app()
|
||||
|
||||
mocks = self._make_rag_import_mocks()
|
||||
|
||||
with isolated_sys_modules(mocks):
|
||||
from langbot.pkg.rag.service.runtime import RAGRuntimeService
|
||||
|
||||
service = RAGRuntimeService(mock_app)
|
||||
|
||||
await service.get_file_stream('knowledge/./files/doc.pdf')
|
||||
|
||||
mock_app.storage_mgr.storage_provider.load.assert_called_once_with('knowledge/files/doc.pdf')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"storage_path",
|
||||
[
|
||||
"",
|
||||
"../secret.txt",
|
||||
"/absolute/path.txt",
|
||||
"..\\secret.txt",
|
||||
"nested\\..\\secret.txt",
|
||||
"%2e%2e/secret.txt",
|
||||
"nested/%2e%2e/secret.txt",
|
||||
"C:\\secret.txt",
|
||||
"safe/\x00file.txt",
|
||||
],
|
||||
)
|
||||
async def test_get_file_stream_rejects_unsafe_paths(self, storage_path):
|
||||
"""Traversal, absolute, encoded, and Windows-style paths are rejected."""
|
||||
mock_app = self._create_mock_app()
|
||||
|
||||
mocks = self._make_rag_import_mocks()
|
||||
|
||||
with isolated_sys_modules(mocks):
|
||||
from langbot.pkg.rag.service.runtime import RAGRuntimeService
|
||||
|
||||
service = RAGRuntimeService(mock_app)
|
||||
|
||||
with pytest.raises(ValueError, match='Invalid storage path'):
|
||||
await service.get_file_stream(storage_path)
|
||||
|
||||
mock_app.storage_mgr.storage_provider.load.assert_not_called()
|
||||
|
||||
@@ -191,18 +191,19 @@ class TestGetFuncSchema:
|
||||
assert result['parameters']['properties']['param_name']['description'] == 'This is the param description.'
|
||||
|
||||
def test_missing_parameter_doc_uses_empty_description(self):
|
||||
"""Undocumented parameters should not break schema generation."""
|
||||
"""Test that undocumented parameters do not crash schema generation."""
|
||||
funcschema = get_funcschema_module()
|
||||
|
||||
def sample_function(documented: str, undocumented: int):
|
||||
"""Sample function.
|
||||
def partially_documented_func(documented: str, undocumented: int):
|
||||
"""Function with one undocumented param.
|
||||
|
||||
Args:
|
||||
documented(str): documented parameter description
|
||||
documented: Documented parameter.
|
||||
"""
|
||||
pass
|
||||
|
||||
result = funcschema.get_func_schema(sample_function)
|
||||
result = funcschema.get_func_schema(partially_documented_func)
|
||||
|
||||
assert result['parameters']['properties']['documented']['description'] == 'documented parameter description'
|
||||
assert result['parameters']['properties']['undocumented']['description'] == ''
|
||||
props = result['parameters']['properties']
|
||||
assert props['documented']['description'] == 'Documented parameter.'
|
||||
assert props['undocumented']['description'] == ''
|
||||
|
||||
@@ -56,29 +56,21 @@ class TestGetQQImageDownloadableUrl:
|
||||
# Fragment is not included in query string parsing
|
||||
assert "http://example.com/image.jpg" in result_url
|
||||
|
||||
def test_https_url(self):
|
||||
"""Parse HTTPS URL and preserve its scheme."""
|
||||
def test_https_url_preserves_scheme(self):
|
||||
"""Parse HTTPS URL without downgrading the scheme."""
|
||||
url = "https://example.com/image.jpg"
|
||||
result_url, query = get_qq_image_downloadable_url(url)
|
||||
|
||||
assert result_url == "https://example.com/image.jpg"
|
||||
assert query == {}
|
||||
|
||||
def test_preserves_qq_https_scheme_and_query(self):
|
||||
"""QQ image URLs keep HTTPS and query parameters."""
|
||||
result_url, query = get_qq_image_downloadable_url(
|
||||
'https://gchat.qpic.cn/gchatpic_new/abc/0?term=2&is_origin=1'
|
||||
)
|
||||
def test_missing_scheme_defaults_to_http(self):
|
||||
"""Parse scheme-less URL with the existing HTTP default."""
|
||||
url = "example.com/image.jpg?param=value"
|
||||
result_url, query = get_qq_image_downloadable_url(url)
|
||||
|
||||
assert result_url == 'https://gchat.qpic.cn/gchatpic_new/abc/0'
|
||||
assert query == {'term': ['2'], 'is_origin': ['1']}
|
||||
|
||||
def test_defaults_missing_scheme_to_http(self):
|
||||
"""Scheme-less image URLs default to HTTP."""
|
||||
result_url, query = get_qq_image_downloadable_url('gchat.qpic.cn/gchatpic_new/abc/0?term=2')
|
||||
|
||||
assert result_url == 'http://gchat.qpic.cn/gchatpic_new/abc/0'
|
||||
assert query == {'term': ['2']}
|
||||
assert result_url == "http://example.com/image.jpg"
|
||||
assert query == {"param": ["value"]}
|
||||
|
||||
|
||||
class TestExtractB64AndFormat:
|
||||
|
||||
@@ -75,61 +75,6 @@ class TestPkgMgr:
|
||||
]
|
||||
mock_pipmain.assert_called_once_with(expected_args)
|
||||
|
||||
def test_install_requirements_defaults_extra_params_to_none(self):
|
||||
"""install_requirements should not use a mutable default for extra_params."""
|
||||
signature = inspect.signature(pkgmgr.install_requirements)
|
||||
|
||||
assert signature.parameters['extra_params'].default is None
|
||||
|
||||
def test_install_requirements_omitted_extra_params_uses_independent_base_commands(self, monkeypatch):
|
||||
"""Omitted extra_params should not share mutable state across calls."""
|
||||
calls = []
|
||||
monkeypatch.setattr(pkgmgr, 'pipmain', calls.append)
|
||||
|
||||
pkgmgr.install_requirements('requirements.txt')
|
||||
pkgmgr.install_requirements('requirements-dev.txt')
|
||||
|
||||
assert calls == [
|
||||
[
|
||||
'install',
|
||||
'-r',
|
||||
'requirements.txt',
|
||||
'-i',
|
||||
'https://pypi.tuna.tsinghua.edu.cn/simple',
|
||||
'--trusted-host',
|
||||
'pypi.tuna.tsinghua.edu.cn',
|
||||
],
|
||||
[
|
||||
'install',
|
||||
'-r',
|
||||
'requirements-dev.txt',
|
||||
'-i',
|
||||
'https://pypi.tuna.tsinghua.edu.cn/simple',
|
||||
'--trusted-host',
|
||||
'pypi.tuna.tsinghua.edu.cn',
|
||||
],
|
||||
]
|
||||
|
||||
def test_install_requirements_preserves_explicit_extra_params(self, monkeypatch):
|
||||
"""Explicit extra_params should be appended to the generated pip command."""
|
||||
calls = []
|
||||
monkeypatch.setattr(pkgmgr, 'pipmain', calls.append)
|
||||
|
||||
pkgmgr.install_requirements('requirements.txt', extra_params=['--no-deps'])
|
||||
|
||||
assert calls == [
|
||||
[
|
||||
'install',
|
||||
'-r',
|
||||
'requirements.txt',
|
||||
'-i',
|
||||
'https://pypi.tuna.tsinghua.edu.cn/simple',
|
||||
'--trusted-host',
|
||||
'pypi.tuna.tsinghua.edu.cn',
|
||||
'--no-deps',
|
||||
]
|
||||
]
|
||||
|
||||
def test_install_requirements_with_extra_params(self):
|
||||
"""install_requirements handles extra params."""
|
||||
with patch('langbot.pkg.utils.pkgmgr.pipmain') as mock_pipmain:
|
||||
@@ -155,3 +100,34 @@ class TestPkgMgr:
|
||||
call_args = mock_pipmain.call_args[0][0]
|
||||
assert '--no-cache-dir' in call_args
|
||||
assert '--verbose' in call_args
|
||||
|
||||
def test_install_requirements_defaults_extra_params_to_none(self):
|
||||
"""install_requirements does not use a mutable list default."""
|
||||
signature = inspect.signature(pkgmgr.install_requirements)
|
||||
|
||||
assert signature.parameters['extra_params'].default is None
|
||||
|
||||
def test_install_requirements_omitted_extra_params_are_isolated(self):
|
||||
"""Repeated calls without extra_params use independent base commands."""
|
||||
with patch('langbot.pkg.utils.pkgmgr.pipmain') as mock_pipmain:
|
||||
pkgmgr.install_requirements('requirements.txt')
|
||||
pkgmgr.install_requirements('requirements-dev.txt')
|
||||
|
||||
assert mock_pipmain.call_args_list[0].args[0] == [
|
||||
'install',
|
||||
'-r',
|
||||
'requirements.txt',
|
||||
'-i',
|
||||
'https://pypi.tuna.tsinghua.edu.cn/simple',
|
||||
'--trusted-host',
|
||||
'pypi.tuna.tsinghua.edu.cn',
|
||||
]
|
||||
assert mock_pipmain.call_args_list[1].args[0] == [
|
||||
'install',
|
||||
'-r',
|
||||
'requirements-dev.txt',
|
||||
'-i',
|
||||
'https://pypi.tuna.tsinghua.edu.cn/simple',
|
||||
'--trusted-host',
|
||||
'pypi.tuna.tsinghua.edu.cn',
|
||||
]
|
||||
|
||||
@@ -87,22 +87,6 @@ class TestGetRunnerCategory:
|
||||
assert get_runner_category("test", "https://example.com") == RunnerCategory.CLOUD
|
||||
assert get_runner_category("test", "https://myserver.example.org") == RunnerCategory.CLOUD
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'runner_url',
|
||||
[
|
||||
'api.dify.ai/v1',
|
||||
'localhost:7860',
|
||||
'https:///v1',
|
||||
'https://',
|
||||
'https://exa mple.com',
|
||||
'http://[::1',
|
||||
'http://localhost:bad',
|
||||
],
|
||||
)
|
||||
def test_invalid_urls_return_unknown(self, runner_url):
|
||||
"""Invalid or incomplete URLs should return UNKNOWN."""
|
||||
assert get_runner_category("test", runner_url) == RunnerCategory.UNKNOWN
|
||||
|
||||
def test_urlparse_exception_returns_unknown(self):
|
||||
"""Exception during URL parsing should return UNKNOWN."""
|
||||
# Test by mocking urlparse to raise an exception
|
||||
@@ -115,36 +99,49 @@ class TestGetRunnerCategory:
|
||||
result = runner.get_runner_category("test", "http://example.com")
|
||||
assert result == RunnerCategory.UNKNOWN
|
||||
|
||||
def test_url_without_scheme_returns_unknown(self):
|
||||
"""URL without scheme should return UNKNOWN."""
|
||||
assert get_runner_category("test", "example.com") == RunnerCategory.UNKNOWN
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'runner_url',
|
||||
"runner_url",
|
||||
[
|
||||
'http://localhost:7860',
|
||||
'http://127.0.0.1:7860',
|
||||
'http://10.0.0.1:7860',
|
||||
'http://172.16.0.1:7860',
|
||||
'http://172.31.255.255:7860',
|
||||
'http://192.168.1.20:7860',
|
||||
'http://[::1]:7860',
|
||||
"api.dify.ai/v1",
|
||||
"localhost:7860",
|
||||
"https:///v1",
|
||||
"https://",
|
||||
"https://exa mple.com",
|
||||
"http://[::1",
|
||||
"http://localhost:bad",
|
||||
],
|
||||
)
|
||||
def test_detects_local_hosts_with_ipaddress(self, runner_url):
|
||||
"""Local hostnames and private IPs should be categorized as LOCAL."""
|
||||
assert get_runner_category('langflow-api', runner_url) == RunnerCategory.LOCAL
|
||||
def test_invalid_urls_return_unknown(self, runner_url):
|
||||
"""Invalid or scheme-less URLs should not default to CLOUD."""
|
||||
assert get_runner_category("test", runner_url) == RunnerCategory.UNKNOWN
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'runner_url',
|
||||
"runner_url",
|
||||
[
|
||||
'http://10.evil.com',
|
||||
'http://192.168.example.com',
|
||||
"http://localhost:7860",
|
||||
"http://127.0.0.1:7860",
|
||||
"http://10.0.0.1:7860",
|
||||
"http://172.16.0.1:7860",
|
||||
"http://172.31.255.255:7860",
|
||||
"http://192.168.1.20:7860",
|
||||
"http://[::1]:7860",
|
||||
],
|
||||
)
|
||||
def test_local_hosts_are_detected_with_ipaddress(self, runner_url):
|
||||
"""Loopback/private IP addresses and localhost should be LOCAL."""
|
||||
assert get_runner_category("test", runner_url) == RunnerCategory.LOCAL
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"runner_url",
|
||||
[
|
||||
"http://10.evil.com",
|
||||
"http://192.168.example.com",
|
||||
],
|
||||
)
|
||||
def test_private_ip_prefix_domains_are_not_local(self, runner_url):
|
||||
"""Domain names that only look like private IP prefixes should not be LOCAL."""
|
||||
assert get_runner_category('langflow-api', runner_url) == RunnerCategory.CLOUD
|
||||
assert get_runner_category("test", runner_url) == RunnerCategory.CLOUD
|
||||
|
||||
|
||||
class TestIsCloudRunner:
|
||||
"""Test is_cloud_runner helper function."""
|
||||
|
||||
@@ -338,9 +338,7 @@ function NavItems({
|
||||
tooltip={config.name}
|
||||
>
|
||||
{config.icon}
|
||||
<span className="cursor-pointer select-none">
|
||||
{config.name}
|
||||
</span>
|
||||
<span>{config.name}</span>
|
||||
</SidebarMenuButton>
|
||||
</SidebarMenuItem>
|
||||
);
|
||||
@@ -730,9 +728,7 @@ function NavItems({
|
||||
}}
|
||||
>
|
||||
{config.icon}
|
||||
<span className="cursor-pointer select-none">
|
||||
{config.name}
|
||||
</span>
|
||||
<span>{config.name}</span>
|
||||
<div className="ml-auto flex items-center gap-0.5 -mr-1">
|
||||
{canCreate &&
|
||||
(isPlugin ? (
|
||||
@@ -1112,7 +1108,7 @@ function PluginPagesNav() {
|
||||
className="select-none"
|
||||
>
|
||||
{pluginIcon}
|
||||
<span className="cursor-pointer">{page.name}</span>
|
||||
<span>{page.name}</span>
|
||||
</SidebarMenuButton>
|
||||
</SidebarMenuItem>
|
||||
);
|
||||
@@ -1132,7 +1128,7 @@ function PluginPagesNav() {
|
||||
className="select-none"
|
||||
>
|
||||
{pluginIcon}
|
||||
<span className="cursor-pointer">{label}</span>
|
||||
<span>{label}</span>
|
||||
<ChevronRight className="ml-auto size-4 transition-transform duration-200 group-data-[state=open]/collapsible:rotate-90" />
|
||||
</SidebarMenuButton>
|
||||
</CollapsibleTrigger>
|
||||
@@ -1148,9 +1144,7 @@ function PluginPagesNav() {
|
||||
onClick={() => navigate(route)}
|
||||
className="select-none"
|
||||
>
|
||||
<span className="cursor-pointer">
|
||||
{page.name}
|
||||
</span>
|
||||
<span>{page.name}</span>
|
||||
</SidebarMenuSubButton>
|
||||
</SidebarMenuSubItem>
|
||||
);
|
||||
|
||||
@@ -295,7 +295,7 @@ export default function ModelsDialog({
|
||||
|
||||
async function handleScanModels(
|
||||
providerUuid: string,
|
||||
modelType?: ModelType,
|
||||
modelType: ModelType,
|
||||
): Promise<ScanModelsResult> {
|
||||
try {
|
||||
const resp = await httpClient.scanProviderModels(providerUuid, modelType);
|
||||
@@ -319,22 +319,15 @@ export default function ModelsDialog({
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
for (const item of models) {
|
||||
const effectiveType = item.model.type || modelType;
|
||||
if (effectiveType === 'llm') {
|
||||
if (modelType === 'llm') {
|
||||
await httpClient.createProviderLLMModel({
|
||||
name: item.model.name,
|
||||
provider_uuid: providerUuid,
|
||||
abilities: item.abilities,
|
||||
extra_args: {},
|
||||
} as never);
|
||||
} else if (effectiveType === 'embedding') {
|
||||
await httpClient.createProviderEmbeddingModel({
|
||||
name: item.model.name,
|
||||
provider_uuid: providerUuid,
|
||||
extra_args: {},
|
||||
} as never);
|
||||
} else {
|
||||
await httpClient.createProviderRerankModel({
|
||||
await httpClient.createProviderEmbeddingModel({
|
||||
name: item.model.name,
|
||||
provider_uuid: providerUuid,
|
||||
extra_args: {},
|
||||
|
||||
@@ -73,13 +73,10 @@ export default function ProviderForm({
|
||||
>([]);
|
||||
|
||||
useEffect(() => {
|
||||
async function init() {
|
||||
await loadRequesters();
|
||||
if (providerId) {
|
||||
await loadProvider(providerId);
|
||||
}
|
||||
loadRequesters();
|
||||
if (providerId) {
|
||||
loadProvider(providerId);
|
||||
}
|
||||
init();
|
||||
}, [providerId]);
|
||||
|
||||
async function loadRequesters() {
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
Wrench,
|
||||
Check,
|
||||
RefreshCw,
|
||||
Search,
|
||||
} from 'lucide-react';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Input } from '@/components/ui/input';
|
||||
@@ -32,8 +33,6 @@ import ExtraArgsEditor from './ExtraArgsEditor';
|
||||
|
||||
interface AddModelPopoverProps {
|
||||
isOpen: boolean;
|
||||
initialMode?: 'manual' | 'scan';
|
||||
trigger?: React.ReactNode;
|
||||
onOpen: () => void;
|
||||
onClose: () => void;
|
||||
onAddModel: (
|
||||
@@ -42,7 +41,7 @@ interface AddModelPopoverProps {
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
) => Promise<void>;
|
||||
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
|
||||
onScanModels: (modelType: ModelType) => Promise<ScanModelsResult>;
|
||||
onAddScannedModels: (
|
||||
modelType: ModelType,
|
||||
models: SelectedScannedModel[],
|
||||
@@ -61,8 +60,6 @@ interface AddModelPopoverProps {
|
||||
|
||||
export default function AddModelPopover({
|
||||
isOpen,
|
||||
initialMode = 'manual',
|
||||
trigger,
|
||||
onOpen,
|
||||
onClose,
|
||||
onAddModel,
|
||||
@@ -95,7 +92,7 @@ export default function AddModelPopover({
|
||||
const wasOpen = prevIsOpenRef.current;
|
||||
if (isOpen && !wasOpen) {
|
||||
setTab('llm');
|
||||
setMode(initialMode);
|
||||
setMode('manual');
|
||||
setName('');
|
||||
setAbilities([]);
|
||||
setExtraArgs([]);
|
||||
@@ -104,12 +101,8 @@ export default function AddModelPopover({
|
||||
setSelectedScannedModels({});
|
||||
setScanQuery('');
|
||||
onResetTestResult();
|
||||
if (initialMode === 'scan') {
|
||||
handleScan();
|
||||
}
|
||||
}
|
||||
prevIsOpenRef.current = isOpen;
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [isOpen, onResetTestResult]);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -129,8 +122,9 @@ export default function AddModelPopover({
|
||||
const handleScan = async () => {
|
||||
setScanLoading(true);
|
||||
try {
|
||||
const result = await onScanModels(trigger ? undefined : tab);
|
||||
const result = await onScanModels(tab);
|
||||
|
||||
// Enrich abilities from debug.response.data (e.g. features.tools.function_calling)
|
||||
const debugData = (
|
||||
result.debug?.response as { data?: Record<string, unknown>[] }
|
||||
)?.data;
|
||||
@@ -149,9 +143,9 @@ export default function AddModelPopover({
|
||||
| undefined;
|
||||
const tools = features?.tools as Record<string, unknown> | undefined;
|
||||
if (tools?.function_calling === true) {
|
||||
const nextAbilities = new Set(model.abilities || []);
|
||||
nextAbilities.add('func_call');
|
||||
model.abilities = [...nextAbilities];
|
||||
const abilities = new Set(model.abilities || []);
|
||||
abilities.add('func_call');
|
||||
model.abilities = [...abilities];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -253,321 +247,305 @@ export default function AddModelPopover({
|
||||
onOpenChange={(open) => (open ? onOpen() : onClose())}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
{trigger || (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 text-xs"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<Plus className="h-3 w-3 mr-1" />
|
||||
{t('models.addModel')}
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 text-xs"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<Plus className="h-3 w-3 mr-1" />
|
||||
{t('models.addModel')}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="w-[min(24rem,calc(100vw-2rem))] max-h-[calc(100vh-8rem)] flex flex-col overflow-hidden"
|
||||
className="w-[min(24rem,calc(100vw-2rem))] max-h-[70vh] overflow-y-auto overscroll-none focus:outline-none focus-visible:outline-none focus-visible:ring-0"
|
||||
style={{
|
||||
maxHeight: 'min(70vh, var(--radix-popover-content-available-height))',
|
||||
}}
|
||||
align="end"
|
||||
side="bottom"
|
||||
side="left"
|
||||
sideOffset={8}
|
||||
collisionPadding={16}
|
||||
onWheel={(e) => e.stopPropagation()}
|
||||
onTouchMove={(e) => e.stopPropagation()}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<Tabs
|
||||
value={tab}
|
||||
onValueChange={(v) => setTab(v as ModelType)}
|
||||
className="flex flex-col min-h-0 flex-1"
|
||||
>
|
||||
<div className="flex-shrink-0">
|
||||
{!(trigger && initialMode === 'scan') && (
|
||||
<TabsList className="grid w-full grid-cols-3">
|
||||
<TabsTrigger value="llm">
|
||||
<MessageSquareText className="h-4 w-4 mr-1" />
|
||||
{t('models.chat')}
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="embedding">
|
||||
<Cpu className="h-4 w-4 mr-1" />
|
||||
{t('models.embedding')}
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="rerank">
|
||||
<ArrowUpDown className="h-4 w-4 mr-1" />
|
||||
{t('models.rerank')}
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
)}
|
||||
</div>
|
||||
<Tabs value={tab} onValueChange={(v) => setTab(v as ModelType)}>
|
||||
<TabsList className="grid w-full grid-cols-3">
|
||||
<TabsTrigger value="llm">
|
||||
<MessageSquareText className="h-4 w-4 mr-1" />
|
||||
{t('models.chat')}
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="embedding">
|
||||
<Cpu className="h-4 w-4 mr-1" />
|
||||
{t('models.embedding')}
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="rerank">
|
||||
<ArrowUpDown className="h-4 w-4 mr-1" />
|
||||
{t('models.rerank')}
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<div className="overflow-y-auto flex-1 min-h-0">
|
||||
<Tabs
|
||||
value={mode}
|
||||
onValueChange={(v) => setMode(v as 'manual' | 'scan')}
|
||||
>
|
||||
{!trigger && (
|
||||
<TabsList className="grid w-full grid-cols-2 mt-3">
|
||||
<TabsTrigger value="manual">
|
||||
{t('models.manualAdd')}
|
||||
</TabsTrigger>
|
||||
<TabsTrigger value="scan">{t('models.scanAdd')}</TabsTrigger>
|
||||
</TabsList>
|
||||
)}
|
||||
<Tabs
|
||||
value={mode}
|
||||
onValueChange={(v) => setMode(v as 'manual' | 'scan')}
|
||||
>
|
||||
<TabsList className="grid w-full grid-cols-2 mt-3">
|
||||
<TabsTrigger value="manual">{t('models.manualAdd')}</TabsTrigger>
|
||||
<TabsTrigger value="scan">{t('models.scanAdd')}</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="manual" className="mt-3">
|
||||
<div className="space-y-3">
|
||||
<div className="space-y-2">
|
||||
<Label>{t('models.modelName')}</Label>
|
||||
<Input
|
||||
placeholder={t('models.modelName')}
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{tab === 'llm' && (
|
||||
<div className="space-y-2">
|
||||
<Label>{t('models.abilities')}</Label>
|
||||
<div className="flex gap-4">
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id="add-vision"
|
||||
checked={abilities.includes('vision')}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleAbility('vision', checked as boolean)
|
||||
}
|
||||
/>
|
||||
<Label htmlFor="add-vision" className="text-sm">
|
||||
<Eye className="h-3 w-3 inline mr-1" />
|
||||
{t('models.visionAbility')}
|
||||
</Label>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id="add-func-call"
|
||||
checked={abilities.includes('func_call')}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleAbility('func_call', checked as boolean)
|
||||
}
|
||||
/>
|
||||
<Label htmlFor="add-func-call" className="text-sm">
|
||||
<Wrench className="h-3 w-3 inline mr-1" />
|
||||
{t('models.functionCallAbility')}
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<ExtraArgsEditor
|
||||
args={extraArgs}
|
||||
onChange={setExtraArgs}
|
||||
modelType={tab}
|
||||
<TabsContent value="manual" className="mt-3">
|
||||
<div className="space-y-3">
|
||||
<div className="space-y-2">
|
||||
<Label>{t('models.modelName')}</Label>
|
||||
<Input
|
||||
placeholder={t('models.modelName')}
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
/>
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
className="flex-1"
|
||||
size="sm"
|
||||
onClick={handleAdd}
|
||||
disabled={isSubmitting || isTesting}
|
||||
>
|
||||
{isSubmitting ? t('common.saving') : t('common.add')}
|
||||
</Button>
|
||||
<Button
|
||||
className="flex-1"
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={handleTest}
|
||||
disabled={isSubmitting || isTesting}
|
||||
>
|
||||
{isTesting ? (
|
||||
t('common.loading')
|
||||
) : testResult?.success ? (
|
||||
<>
|
||||
<Check className="h-4 w-4 mr-1 text-green-500" />
|
||||
{(testResult.duration / 1000).toFixed(1)}s
|
||||
</>
|
||||
) : (
|
||||
t('common.test')
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="scan" className="space-y-2 mt-0 pt-0">
|
||||
{scanLoading ? (
|
||||
<div className="flex items-center justify-center py-4">
|
||||
<RefreshCw className="h-4 w-4 mr-2 animate-spin text-muted-foreground" />
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{t('models.scanModels')}...
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="space-y-2">
|
||||
<Input
|
||||
placeholder={t('models.searchScannedModels')}
|
||||
value={scanQuery}
|
||||
onChange={(e) => setScanQuery(e.target.value)}
|
||||
disabled={scannedModels.length === 0}
|
||||
/>
|
||||
{selectableModels.length > 0 && (
|
||||
<div className="flex items-center gap-2 pt-1">
|
||||
<Checkbox
|
||||
id="scan-select-all"
|
||||
checked={allSelected}
|
||||
onCheckedChange={toggleSelectAll}
|
||||
/>
|
||||
<Label
|
||||
htmlFor="scan-select-all"
|
||||
className="text-sm font-medium"
|
||||
>
|
||||
{t('models.selectAll')}
|
||||
<span className="text-muted-foreground ml-1">
|
||||
({Object.keys(selectedScannedModels).length}/
|
||||
{selectableModels.length})
|
||||
</span>
|
||||
</Label>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div
|
||||
className="h-64 overflow-y-auto overscroll-contain rounded-md border"
|
||||
onWheel={(e) => e.stopPropagation()}
|
||||
>
|
||||
<div className="p-3 space-y-2">
|
||||
{filteredScannedModels.length === 0 ? (
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{scannedModels.length === 0
|
||||
? t('models.noScannedModels')
|
||||
: t('models.noScannedModelsMatch')}
|
||||
</p>
|
||||
) : (
|
||||
filteredScannedModels.map((model) => {
|
||||
const isSelected = Boolean(
|
||||
selectedScannedModels[model.id],
|
||||
);
|
||||
const selectedAbilities =
|
||||
selectedScannedModels[model.id]?.abilities || [];
|
||||
return (
|
||||
<div
|
||||
key={model.id}
|
||||
className="rounded-md border p-3 space-y-2"
|
||||
>
|
||||
<div className="flex items-start gap-3">
|
||||
<Checkbox
|
||||
checked={isSelected || model.already_added}
|
||||
disabled={model.already_added}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleScannedModel(
|
||||
model,
|
||||
checked as boolean,
|
||||
)
|
||||
}
|
||||
/>
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="text-sm font-medium break-all">
|
||||
{model.name}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.already_added
|
||||
? t('models.alreadyAdded')
|
||||
: model.type === 'llm'
|
||||
? t('models.chat')
|
||||
: model.type === 'embedding'
|
||||
? t('models.embedding')
|
||||
: t('models.rerank')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{model.type === 'llm' &&
|
||||
isSelected &&
|
||||
!model.already_added && (
|
||||
<div className="flex gap-4 pl-7">
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id={`scan-vision-${model.id}`}
|
||||
checked={selectedAbilities.includes(
|
||||
'vision',
|
||||
)}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleScannedModelAbility(
|
||||
model.id,
|
||||
'vision',
|
||||
checked as boolean,
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Label
|
||||
htmlFor={`scan-vision-${model.id}`}
|
||||
className="text-sm"
|
||||
>
|
||||
<Eye className="h-3 w-3 inline mr-1" />
|
||||
{t('models.visionAbility')}
|
||||
</Label>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id={`scan-func-${model.id}`}
|
||||
checked={selectedAbilities.includes(
|
||||
'func_call',
|
||||
)}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleScannedModelAbility(
|
||||
model.id,
|
||||
'func_call',
|
||||
checked as boolean,
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Label
|
||||
htmlFor={`scan-func-${model.id}`}
|
||||
className="text-sm"
|
||||
>
|
||||
<Wrench className="h-3 w-3 inline mr-1" />
|
||||
{t('models.functionCallAbility')}
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})
|
||||
)}
|
||||
{tab === 'llm' && (
|
||||
<div className="space-y-2">
|
||||
<Label>{t('models.abilities')}</Label>
|
||||
<div className="flex gap-4">
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id="add-vision"
|
||||
checked={abilities.includes('vision')}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleAbility('vision', checked as boolean)
|
||||
}
|
||||
/>
|
||||
<Label htmlFor="add-vision" className="text-sm">
|
||||
<Eye className="h-3 w-3 inline mr-1" />
|
||||
{t('models.visionAbility')}
|
||||
</Label>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id="add-func-call"
|
||||
checked={abilities.includes('func_call')}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleAbility('func_call', checked as boolean)
|
||||
}
|
||||
/>
|
||||
<Label htmlFor="add-func-call" className="text-sm">
|
||||
<Wrench className="h-3 w-3 inline mr-1" />
|
||||
{t('models.functionCallAbility')}
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<ExtraArgsEditor
|
||||
args={extraArgs}
|
||||
onChange={setExtraArgs}
|
||||
modelType={tab}
|
||||
/>
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
className="flex-1"
|
||||
size="sm"
|
||||
onClick={handleAddScanned}
|
||||
disabled={
|
||||
isSubmitting ||
|
||||
scanLoading ||
|
||||
Object.keys(selectedScannedModels).length === 0
|
||||
}
|
||||
onClick={handleAdd}
|
||||
disabled={isSubmitting || isTesting}
|
||||
>
|
||||
{isSubmitting
|
||||
? t('common.saving')
|
||||
: t('models.addSelectedModels')}
|
||||
{isSubmitting ? t('common.saving') : t('common.add')}
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
className="flex-1"
|
||||
size="sm"
|
||||
onClick={handleScan}
|
||||
disabled={scanLoading || isSubmitting}
|
||||
variant="outline"
|
||||
onClick={handleTest}
|
||||
disabled={isSubmitting || isTesting}
|
||||
>
|
||||
<RefreshCw
|
||||
className={`h-3.5 w-3.5 ${scanLoading ? 'animate-spin' : ''}`}
|
||||
/>
|
||||
{isTesting ? (
|
||||
t('common.loading')
|
||||
) : testResult?.success ? (
|
||||
<>
|
||||
<Check className="h-4 w-4 mr-1 text-green-500" />
|
||||
{(testResult.duration / 1000).toFixed(1)}s
|
||||
</>
|
||||
) : (
|
||||
t('common.test')
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="scan" className="space-y-3 mt-3">
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{t('models.scanModelsHint')}
|
||||
</div>
|
||||
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
className="flex-1"
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={handleScan}
|
||||
disabled={scanLoading || isSubmitting}
|
||||
>
|
||||
{scanLoading ? (
|
||||
<RefreshCw className="h-4 w-4 mr-1 animate-spin" />
|
||||
) : (
|
||||
<Search className="h-4 w-4 mr-1" />
|
||||
)}
|
||||
{t('models.scanModels')}
|
||||
</Button>
|
||||
<Button
|
||||
className="flex-1"
|
||||
size="sm"
|
||||
onClick={handleAddScanned}
|
||||
disabled={
|
||||
isSubmitting ||
|
||||
scanLoading ||
|
||||
Object.keys(selectedScannedModels).length === 0
|
||||
}
|
||||
>
|
||||
{isSubmitting
|
||||
? t('common.saving')
|
||||
: t('models.addSelectedModels')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>{t('models.scannedModels')}</Label>
|
||||
<Input
|
||||
placeholder={t('models.searchScannedModels')}
|
||||
value={scanQuery}
|
||||
onChange={(e) => setScanQuery(e.target.value)}
|
||||
disabled={scannedModels.length === 0}
|
||||
/>
|
||||
{selectableModels.length > 0 && (
|
||||
<div className="flex items-center gap-2 pt-1">
|
||||
<Checkbox
|
||||
id="scan-select-all"
|
||||
checked={allSelected}
|
||||
onCheckedChange={toggleSelectAll}
|
||||
/>
|
||||
<Label
|
||||
htmlFor="scan-select-all"
|
||||
className="text-sm font-medium"
|
||||
>
|
||||
{t('models.selectAll')}
|
||||
<span className="text-muted-foreground ml-1">
|
||||
({Object.keys(selectedScannedModels).length}/
|
||||
{selectableModels.length})
|
||||
</span>
|
||||
</Label>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div
|
||||
className="h-64 overflow-y-auto overscroll-none rounded-md border"
|
||||
onWheel={(e) => e.stopPropagation()}
|
||||
>
|
||||
<div className="p-3 space-y-2">
|
||||
{filteredScannedModels.length === 0 ? (
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{scannedModels.length === 0
|
||||
? t('models.noScannedModels')
|
||||
: t('models.noScannedModelsMatch')}
|
||||
</p>
|
||||
) : (
|
||||
filteredScannedModels.map((model) => {
|
||||
const isSelected = Boolean(
|
||||
selectedScannedModels[model.id],
|
||||
);
|
||||
const selectedAbilities =
|
||||
selectedScannedModels[model.id]?.abilities || [];
|
||||
return (
|
||||
<div
|
||||
key={model.id}
|
||||
className="rounded-md border p-3 space-y-2"
|
||||
>
|
||||
<div className="flex items-start gap-3">
|
||||
<Checkbox
|
||||
checked={isSelected || model.already_added}
|
||||
disabled={model.already_added}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleScannedModel(model, checked as boolean)
|
||||
}
|
||||
/>
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="text-sm font-medium break-all">
|
||||
{model.name}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{model.already_added
|
||||
? t('models.alreadyAdded')
|
||||
: model.type === 'llm'
|
||||
? t('models.chat')
|
||||
: model.type === 'embedding'
|
||||
? t('models.embedding')
|
||||
: t('models.rerank')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{tab === 'llm' &&
|
||||
isSelected &&
|
||||
!model.already_added && (
|
||||
<div className="flex gap-4 pl-7">
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id={`scan-vision-${model.id}`}
|
||||
checked={selectedAbilities.includes(
|
||||
'vision',
|
||||
)}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleScannedModelAbility(
|
||||
model.id,
|
||||
'vision',
|
||||
checked as boolean,
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Label
|
||||
htmlFor={`scan-vision-${model.id}`}
|
||||
className="text-sm"
|
||||
>
|
||||
<Eye className="h-3 w-3 inline mr-1" />
|
||||
{t('models.visionAbility')}
|
||||
</Label>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id={`scan-func-${model.id}`}
|
||||
checked={selectedAbilities.includes(
|
||||
'func_call',
|
||||
)}
|
||||
onCheckedChange={(checked) =>
|
||||
toggleScannedModelAbility(
|
||||
model.id,
|
||||
'func_call',
|
||||
checked as boolean,
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Label
|
||||
htmlFor={`scan-func-${model.id}`}
|
||||
className="text-sm"
|
||||
>
|
||||
<Wrench className="h-3 w-3 inline mr-1" />
|
||||
{t('models.functionCallAbility')}
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</Tabs>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
|
||||
@@ -6,7 +6,6 @@ import {
|
||||
Trash2,
|
||||
Settings,
|
||||
LogIn,
|
||||
Radar,
|
||||
} from 'lucide-react';
|
||||
import { httpClient, systemInfo } from '@/app/infra/http/HttpClient';
|
||||
import { ModelProvider } from '@/app/infra/entities/api';
|
||||
@@ -61,7 +60,7 @@ interface ProviderCardProps {
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
) => Promise<void>;
|
||||
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
|
||||
onScanModels: (modelType: ModelType) => Promise<ScanModelsResult>;
|
||||
onAddScannedModels: (
|
||||
modelType: ModelType,
|
||||
models: SelectedScannedModel[],
|
||||
@@ -131,7 +130,6 @@ export default function ProviderCard({
|
||||
const { t } = useTranslation();
|
||||
const [deleteProviderConfirmOpen, setDeleteProviderConfirmOpen] =
|
||||
useState(false);
|
||||
const [addModelMode, setAddModelMode] = useState<'manual' | 'scan'>('manual');
|
||||
|
||||
const canDelete =
|
||||
!isLangBotModels &&
|
||||
@@ -312,75 +310,19 @@ export default function ProviderCard({
|
||||
<div />
|
||||
)}
|
||||
{!isLangBotModels && (
|
||||
<div className="flex items-center gap-1">
|
||||
<AddModelPopover
|
||||
isOpen={
|
||||
addModelPopoverOpen === provider.uuid &&
|
||||
addModelMode === 'manual'
|
||||
}
|
||||
initialMode="manual"
|
||||
trigger={
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-6 text-xs"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setAddModelMode('manual');
|
||||
}}
|
||||
>
|
||||
<Plus className="h-3 w-3 mr-1" />
|
||||
{t('models.addModel')}
|
||||
</Button>
|
||||
}
|
||||
onOpen={() => {
|
||||
setAddModelMode('manual');
|
||||
onOpenAddModel();
|
||||
}}
|
||||
onClose={onCloseAddModel}
|
||||
onAddModel={onAddModel}
|
||||
onScanModels={onScanModels}
|
||||
onAddScannedModels={onAddScannedModels}
|
||||
onTestModel={onTestModel}
|
||||
isSubmitting={isSubmitting}
|
||||
isTesting={isTesting}
|
||||
testResult={testResult}
|
||||
onResetTestResult={onResetTestResult}
|
||||
/>
|
||||
<AddModelPopover
|
||||
isOpen={
|
||||
addModelPopoverOpen === provider.uuid &&
|
||||
addModelMode === 'scan'
|
||||
}
|
||||
initialMode="scan"
|
||||
trigger={
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-6 w-6"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setAddModelMode('scan');
|
||||
}}
|
||||
>
|
||||
<Radar className="h-3 w-3" />
|
||||
</Button>
|
||||
}
|
||||
onOpen={() => {
|
||||
setAddModelMode('scan');
|
||||
onOpenAddModel();
|
||||
}}
|
||||
onClose={onCloseAddModel}
|
||||
onAddModel={onAddModel}
|
||||
onScanModels={onScanModels}
|
||||
onAddScannedModels={onAddScannedModels}
|
||||
onTestModel={onTestModel}
|
||||
isSubmitting={isSubmitting}
|
||||
isTesting={isTesting}
|
||||
testResult={testResult}
|
||||
onResetTestResult={onResetTestResult}
|
||||
/>
|
||||
</div>
|
||||
<AddModelPopover
|
||||
isOpen={addModelPopoverOpen === provider.uuid}
|
||||
onOpen={onOpenAddModel}
|
||||
onClose={onCloseAddModel}
|
||||
onAddModel={onAddModel}
|
||||
onScanModels={onScanModels}
|
||||
onAddScannedModels={onAddScannedModels}
|
||||
onTestModel={onTestModel}
|
||||
isSubmitting={isSubmitting}
|
||||
isTesting={isTesting}
|
||||
testResult={testResult}
|
||||
onResetTestResult={onResetTestResult}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</CardHeader>
|
||||
|
||||
@@ -90,7 +90,7 @@ export interface ProviderCardProps {
|
||||
abilities: string[],
|
||||
extraArgs: ExtraArg[],
|
||||
) => Promise<void>;
|
||||
onScanModels: (modelType?: ModelType) => Promise<ScanModelsResult>;
|
||||
onScanModels: (modelType: ModelType) => Promise<ScanModelsResult>;
|
||||
onAddScannedModels: (
|
||||
modelType: ModelType,
|
||||
models: SelectedScannedModel[],
|
||||
|
||||
@@ -4,16 +4,11 @@ import {
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogFooter,
|
||||
} from '@/components/ui/dialog';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
Loader2,
|
||||
RefreshCw,
|
||||
RotateCw,
|
||||
CheckCircle2,
|
||||
XCircle,
|
||||
} from 'lucide-react';
|
||||
import { Loader2, RefreshCw, CheckCircle2, XCircle } from 'lucide-react';
|
||||
import QRCode from 'qrcode';
|
||||
|
||||
export type QrLoginPlatform = 'feishu' | 'weixin' | 'dingtalk' | 'wecombot';
|
||||
@@ -101,7 +96,7 @@ interface QrCodeLoginDialogProps {
|
||||
onSuccess: (credentials: Record<string, string>) => void;
|
||||
}
|
||||
|
||||
type DialogState = 'connecting' | 'waiting' | 'expired' | 'success' | 'error';
|
||||
type DialogState = 'connecting' | 'waiting' | 'success' | 'error';
|
||||
|
||||
const POLL_INTERVAL_MS = 3000;
|
||||
|
||||
@@ -120,10 +115,8 @@ export default function QrCodeLoginDialog({
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
const pollTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
const countdownRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
const checkExpiredRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const sessionIdRef = useRef<string | null>(null);
|
||||
const baseUrlRef = useRef('');
|
||||
const cleanedRef = useRef(false);
|
||||
|
||||
const onSuccessRef = useRef(onSuccess);
|
||||
@@ -147,14 +140,11 @@ export default function QrCodeLoginDialog({
|
||||
clearInterval(countdownRef.current);
|
||||
countdownRef.current = null;
|
||||
}
|
||||
if (checkExpiredRef.current) {
|
||||
clearInterval(checkExpiredRef.current);
|
||||
checkExpiredRef.current = null;
|
||||
}
|
||||
if (abortRef.current) {
|
||||
abortRef.current.abort();
|
||||
abortRef.current = null;
|
||||
}
|
||||
// Cancel backend session
|
||||
if (sessionIdRef.current) {
|
||||
const token = localStorage.getItem('token');
|
||||
const baseUrl =
|
||||
@@ -181,7 +171,6 @@ export default function QrCodeLoginDialog({
|
||||
|
||||
const token = localStorage.getItem('token');
|
||||
const baseUrl = import.meta.env.VITE_API_BASE_URL || window.location.origin;
|
||||
baseUrlRef.current = baseUrl;
|
||||
const cfg = platformConfigRef.current;
|
||||
|
||||
try {
|
||||
@@ -202,6 +191,8 @@ export default function QrCodeLoginDialog({
|
||||
const { session_id, qr_data_url, qr_url, expire_at } = json.data;
|
||||
sessionIdRef.current = session_id;
|
||||
|
||||
// qr_data_url is a pre-rendered data URL (WeChat);
|
||||
// qr_url is a plain URL string (Feishu) that needs local QR generation.
|
||||
if (qr_data_url) {
|
||||
setQrDataUrl(qr_data_url);
|
||||
} else if (qr_url) {
|
||||
@@ -213,9 +204,11 @@ export default function QrCodeLoginDialog({
|
||||
}
|
||||
setState('waiting');
|
||||
|
||||
// Calculate remaining seconds
|
||||
const remaining = Math.max(0, Math.floor(expire_at - Date.now() / 1000));
|
||||
setExpireIn(remaining);
|
||||
|
||||
// Start countdown
|
||||
countdownRef.current = setInterval(() => {
|
||||
setExpireIn((prev) => {
|
||||
if (prev <= 1) {
|
||||
@@ -229,35 +222,7 @@ export default function QrCodeLoginDialog({
|
||||
});
|
||||
}, 1000);
|
||||
|
||||
// When countdown hits 0, stop polling and show expired state
|
||||
checkExpiredRef.current = setInterval(() => {
|
||||
setExpireIn((current) => {
|
||||
if (current <= 0) {
|
||||
if (checkExpiredRef.current) {
|
||||
clearInterval(checkExpiredRef.current);
|
||||
checkExpiredRef.current = null;
|
||||
}
|
||||
if (pollTimerRef.current) {
|
||||
clearInterval(pollTimerRef.current);
|
||||
pollTimerRef.current = null;
|
||||
}
|
||||
if (sessionIdRef.current) {
|
||||
fetch(
|
||||
`${baseUrlRef.current}${cfg.apiBase}/${sessionIdRef.current}`,
|
||||
{
|
||||
method: 'DELETE',
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
keepalive: true,
|
||||
},
|
||||
).catch(() => {});
|
||||
sessionIdRef.current = null;
|
||||
}
|
||||
setState('expired');
|
||||
}
|
||||
return current;
|
||||
});
|
||||
}, 500);
|
||||
|
||||
// Start polling
|
||||
pollTimerRef.current = setInterval(async () => {
|
||||
try {
|
||||
const pollRes = await fetch(
|
||||
@@ -272,7 +237,7 @@ export default function QrCodeLoginDialog({
|
||||
const { status, error, ...rest } = pollJson.data;
|
||||
|
||||
if (status === 'success') {
|
||||
sessionIdRef.current = null;
|
||||
sessionIdRef.current = null; // backend already cleaned up
|
||||
cleanup();
|
||||
setState('success');
|
||||
setTimeout(() => {
|
||||
@@ -284,14 +249,9 @@ export default function QrCodeLoginDialog({
|
||||
cleanup();
|
||||
setState('error');
|
||||
setErrorMessage(error || tRef.current(cfg.failedKey));
|
||||
} else if (status === 'expired') {
|
||||
sessionIdRef.current = null;
|
||||
cleanup();
|
||||
setExpireIn(0);
|
||||
setState('expired');
|
||||
}
|
||||
} catch {
|
||||
// ignore poll errors
|
||||
// ignore poll errors, will retry next interval
|
||||
}
|
||||
}, POLL_INTERVAL_MS);
|
||||
} catch (err: unknown) {
|
||||
@@ -363,31 +323,6 @@ export default function QrCodeLoginDialog({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* QR code expired — click overlay to refresh */}
|
||||
{state === 'expired' && qrDataUrl && (
|
||||
<div className="flex flex-col items-center space-y-3">
|
||||
<p className="text-sm text-muted-foreground text-center">
|
||||
{t(platformConfig.scanQRCodeKey)}
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
className="relative border rounded-lg p-2 bg-white cursor-pointer group"
|
||||
onClick={() => startLogin()}
|
||||
>
|
||||
<img
|
||||
src={qrDataUrl}
|
||||
alt="QR Code"
|
||||
className="w-56 h-56 opacity-40"
|
||||
/>
|
||||
<div className="absolute inset-0 flex items-center justify-center bg-white/60 rounded-lg group-hover:bg-white/70 transition-colors">
|
||||
<div className="flex items-center justify-center w-16 h-16 rounded-full bg-black/5 group-hover:bg-black/10 transition-colors">
|
||||
<RotateCw className="h-8 w-8 text-muted-foreground" />
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Success */}
|
||||
{state === 'success' && (
|
||||
<div className="flex flex-col items-center space-y-3 py-8">
|
||||
@@ -415,7 +350,7 @@ export default function QrCodeLoginDialog({
|
||||
</div>
|
||||
|
||||
{state === 'error' && (
|
||||
<div className="flex justify-end gap-2">
|
||||
<DialogFooter>
|
||||
<Button variant="outline" onClick={() => handleOpenChange(false)}>
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
@@ -423,7 +358,7 @@ export default function QrCodeLoginDialog({
|
||||
<RefreshCw className="h-4 w-4 mr-1.5" />
|
||||
{t(platformConfig.retryKey)}
|
||||
</Button>
|
||||
</div>
|
||||
</DialogFooter>
|
||||
)}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
Reference in New Issue
Block a user