mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 07:54:19 +00:00
chore: Add PyPI package support for uvx/pip installation (#1764)
* Initial plan * Add package structure and resource path utilities - Created langbot/ package with __init__.py and __main__.py entry point - Added paths utility to find frontend and resource files from package installation - Updated config loading to use resource paths - Updated frontend serving to use resource paths - Added MANIFEST.in for package data inclusion - Updated pyproject.toml with build system and entry points Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> * Add PyPI publishing workflow and update license - Created GitHub Actions workflow to build frontend and publish to PyPI - Added license field to pyproject.toml to fix deprecation warning - Updated .gitignore to exclude build artifacts - Tested package building successfully Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> * Add PyPI installation documentation - Created PYPI_INSTALLATION.md with detailed installation and usage instructions - Updated README.md to feature uvx/pip installation as recommended method - Updated README_EN.md with same changes for English documentation Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> * Address code review feedback - Made package-data configuration more specific to langbot package only - Improved path detection with caching to avoid repeated file I/O - Removed sys.path searching which was incorrect for package data - Removed interactive input() call for non-interactive environment compatibility - Simplified error messages for version check Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> * Fix code review issues - Use specific exception types instead of bare except - Fix misleading comments about directory levels - Remove redundant existence check before makedirs with exist_ok=True - Use context manager for file opening to ensure proper cleanup Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> * Simplify package configuration and document behavioral differences - Removed redundant package-data configuration, relying on MANIFEST.in - Added documentation about behavioral differences between package and source installation - Clarified that include-package-data=true uses MANIFEST.in for data files Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> * chore: update pyproject.toml * chore: try pack templates in langbot/ * chore: update * chore: update * chore: update * chore: update * chore: update * chore: adjust dir structure * chore: fix imports * fix: read default-pipeline-config.json * fix: read default-pipeline-config.json * fix: tests * ci: publish pypi * chore: bump version 4.6.0-beta.1 for testing * chore: add templates/** * fix: send adapters and requesters icons * chore: bump version 4.6.0b2 for testing * chore: add platform field for docker-compose.yaml --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: RockChinQ <45992437+RockChinQ@users.noreply.github.com> Co-authored-by: Junyan Qin <rockchinq@gmail.com>
This commit is contained in:
@@ -0,0 +1,278 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding:utf-8 -*-
|
||||
|
||||
"""对企业微信发送给企业后台的消息加解密示例代码.
|
||||
@copyright: Copyright (c) 1998-2014 Tencent Inc.
|
||||
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
import logging
|
||||
import base64
|
||||
import random
|
||||
import hashlib
|
||||
import time
|
||||
import struct
|
||||
from Crypto.Cipher import AES
|
||||
import xml.etree.cElementTree as ET
|
||||
import socket
|
||||
from langbot.libs.wecom_ai_bot_api import ierror
|
||||
|
||||
|
||||
"""
|
||||
Crypto.Cipher包已不再维护,开发者可以通过以下命令下载安装最新版的加解密工具包
|
||||
pip install pycryptodome
|
||||
"""
|
||||
|
||||
|
||||
class FormatException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def throw_exception(message, exception_class=FormatException):
|
||||
"""my define raise exception function"""
|
||||
raise exception_class(message)
|
||||
|
||||
|
||||
class SHA1:
|
||||
"""计算企业微信的消息签名接口"""
|
||||
|
||||
def getSHA1(self, token, timestamp, nonce, encrypt):
|
||||
"""用SHA1算法生成安全签名
|
||||
@param token: 票据
|
||||
@param timestamp: 时间戳
|
||||
@param encrypt: 密文
|
||||
@param nonce: 随机字符串
|
||||
@return: 安全签名
|
||||
"""
|
||||
try:
|
||||
sortlist = [token, timestamp, nonce, encrypt]
|
||||
sortlist.sort()
|
||||
sha = hashlib.sha1()
|
||||
sha.update(''.join(sortlist).encode())
|
||||
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
|
||||
except Exception as e:
|
||||
logger = logging.getLogger()
|
||||
logger.error(e)
|
||||
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
|
||||
|
||||
|
||||
class XMLParse:
|
||||
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
|
||||
|
||||
# xml消息模板
|
||||
AES_TEXT_RESPONSE_TEMPLATE = """<xml>
|
||||
<Encrypt><![CDATA[%(msg_encrypt)s]]></Encrypt>
|
||||
<MsgSignature><![CDATA[%(msg_signaturet)s]]></MsgSignature>
|
||||
<TimeStamp>%(timestamp)s</TimeStamp>
|
||||
<Nonce><![CDATA[%(nonce)s]]></Nonce>
|
||||
</xml>"""
|
||||
|
||||
def extract(self, xmltext):
|
||||
"""提取出xml数据包中的加密消息
|
||||
@param xmltext: 待提取的xml字符串
|
||||
@return: 提取出的加密消息字符串
|
||||
"""
|
||||
try:
|
||||
xml_tree = ET.fromstring(xmltext)
|
||||
encrypt = xml_tree.find('Encrypt')
|
||||
return ierror.WXBizMsgCrypt_OK, encrypt.text
|
||||
except Exception as e:
|
||||
logger = logging.getLogger()
|
||||
logger.error(e)
|
||||
return ierror.WXBizMsgCrypt_ParseXml_Error, None
|
||||
|
||||
def generate(self, encrypt, signature, timestamp, nonce):
|
||||
"""生成xml消息
|
||||
@param encrypt: 加密后的消息密文
|
||||
@param signature: 安全签名
|
||||
@param timestamp: 时间戳
|
||||
@param nonce: 随机字符串
|
||||
@return: 生成的xml字符串
|
||||
"""
|
||||
resp_dict = {
|
||||
'msg_encrypt': encrypt,
|
||||
'msg_signaturet': signature,
|
||||
'timestamp': timestamp,
|
||||
'nonce': nonce,
|
||||
}
|
||||
resp_xml = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
|
||||
return resp_xml
|
||||
|
||||
|
||||
class PKCS7Encoder:
|
||||
"""提供基于PKCS7算法的加解密接口"""
|
||||
|
||||
block_size = 32
|
||||
|
||||
def encode(self, text):
|
||||
"""对需要加密的明文进行填充补位
|
||||
@param text: 需要进行填充补位操作的明文
|
||||
@return: 补齐明文字符串
|
||||
"""
|
||||
text_length = len(text)
|
||||
# 计算需要填充的位数
|
||||
amount_to_pad = self.block_size - (text_length % self.block_size)
|
||||
if amount_to_pad == 0:
|
||||
amount_to_pad = self.block_size
|
||||
# 获得补位所用的字符
|
||||
pad = chr(amount_to_pad)
|
||||
return text + (pad * amount_to_pad).encode()
|
||||
|
||||
def decode(self, decrypted):
|
||||
"""删除解密后明文的补位字符
|
||||
@param decrypted: 解密后的明文
|
||||
@return: 删除补位字符后的明文
|
||||
"""
|
||||
pad = ord(decrypted[-1])
|
||||
if pad < 1 or pad > 32:
|
||||
pad = 0
|
||||
return decrypted[:-pad]
|
||||
|
||||
|
||||
class Prpcrypt(object):
|
||||
"""提供接收和推送给企业微信消息的加解密接口"""
|
||||
|
||||
def __init__(self, key):
|
||||
# self.key = base64.b64decode(key+"=")
|
||||
self.key = key
|
||||
# 设置加解密模式为AES的CBC模式
|
||||
self.mode = AES.MODE_CBC
|
||||
|
||||
def encrypt(self, text, receiveid):
|
||||
"""对明文进行加密
|
||||
@param text: 需要加密的明文
|
||||
@return: 加密得到的字符串
|
||||
"""
|
||||
# 16位随机字符串添加到明文开头
|
||||
text = text.encode()
|
||||
text = self.get_random_str() + struct.pack('I', socket.htonl(len(text))) + text + receiveid.encode()
|
||||
|
||||
# 使用自定义的填充方式对明文进行补位填充
|
||||
pkcs7 = PKCS7Encoder()
|
||||
text = pkcs7.encode(text)
|
||||
# 加密
|
||||
cryptor = AES.new(self.key, self.mode, self.key[:16])
|
||||
try:
|
||||
ciphertext = cryptor.encrypt(text)
|
||||
# 使用BASE64对加密后的字符串进行编码
|
||||
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
|
||||
except Exception as e:
|
||||
logger = logging.getLogger()
|
||||
logger.error(e)
|
||||
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
|
||||
|
||||
def decrypt(self, text, receiveid):
|
||||
"""对解密后的明文进行补位删除
|
||||
@param text: 密文
|
||||
@return: 删除填充补位后的明文
|
||||
"""
|
||||
try:
|
||||
cryptor = AES.new(self.key, self.mode, self.key[:16])
|
||||
# 使用BASE64对密文进行解码,然后AES-CBC解密
|
||||
plain_text = cryptor.decrypt(base64.b64decode(text))
|
||||
except Exception as e:
|
||||
logger = logging.getLogger()
|
||||
logger.error(e)
|
||||
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
|
||||
try:
|
||||
pad = plain_text[-1]
|
||||
# 去掉补位字符串
|
||||
# pkcs7 = PKCS7Encoder()
|
||||
# plain_text = pkcs7.encode(plain_text)
|
||||
# 去除16位随机字符串
|
||||
content = plain_text[16:-pad]
|
||||
xml_len = socket.ntohl(struct.unpack('I', content[:4])[0])
|
||||
xml_content = content[4 : xml_len + 4]
|
||||
from_receiveid = content[xml_len + 4 :]
|
||||
except Exception as e:
|
||||
logger = logging.getLogger()
|
||||
logger.error(e)
|
||||
return ierror.WXBizMsgCrypt_IllegalBuffer, None
|
||||
|
||||
if from_receiveid.decode('utf8') != receiveid:
|
||||
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
|
||||
return 0, xml_content
|
||||
|
||||
def get_random_str(self):
|
||||
"""随机生成16位字符串
|
||||
@return: 16位字符串
|
||||
"""
|
||||
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
||||
|
||||
|
||||
class WXBizMsgCrypt(object):
|
||||
# 构造函数
|
||||
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
|
||||
try:
|
||||
self.key = base64.b64decode(sEncodingAESKey + '=')
|
||||
assert len(self.key) == 32
|
||||
except Exception:
|
||||
throw_exception('[error]: EncodingAESKey unvalid !', FormatException)
|
||||
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
|
||||
self.m_sToken = sToken
|
||||
self.m_sReceiveId = sReceiveId
|
||||
|
||||
# 验证URL
|
||||
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
|
||||
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
|
||||
# @param sNonce: 随机串,对应URL参数的nonce
|
||||
# @param sEchoStr: 随机串,对应URL参数的echostr
|
||||
# @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效
|
||||
# @return:成功0,失败返回对应的错误码
|
||||
|
||||
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if not signature == sMsgSignature:
|
||||
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
|
||||
return ret, sReplyEchoStr
|
||||
|
||||
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
|
||||
# 将企业回复用户的消息加密打包
|
||||
# @param sReplyMsg: 企业号待回复用户的消息,xml格式的字符串
|
||||
# @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间
|
||||
# @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce
|
||||
# sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的xml格式的字符串,
|
||||
# return:成功0,sEncryptMsg,失败返回对应的错误码None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
|
||||
encrypt = encrypt.decode('utf8')
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if timestamp is None:
|
||||
timestamp = str(int(time.time()))
|
||||
# 生成安全签名
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
xmlParse = XMLParse()
|
||||
return ret, xmlParse.generate(encrypt, signature, timestamp, sNonce)
|
||||
|
||||
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
|
||||
# 检验消息的真实性,并且获取解密后的明文
|
||||
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
|
||||
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
|
||||
# @param sNonce: 随机串,对应URL参数的nonce
|
||||
# @param sPostData: 密文,对应POST请求的数据
|
||||
# xml_content: 解密后的原文,当return返回0时有效
|
||||
# @return: 成功0,失败返回对应的错误码
|
||||
# 验证安全签名
|
||||
xmlParse = XMLParse()
|
||||
ret, encrypt = xmlParse.extract(sPostData)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if not signature == sMsgSignature:
|
||||
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, xml_content = pc.decrypt(encrypt, self.m_sReceiveId)
|
||||
return ret, xml_content
|
||||
@@ -0,0 +1,587 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
from Crypto.Cipher import AES
|
||||
from quart import Quart, request, Response, jsonify
|
||||
|
||||
from langbot.libs.wecom_ai_bot_api import wecombotevent
|
||||
from langbot.libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
from langbot.pkg.platform.logger import EventLogger
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamChunk:
|
||||
"""描述单次推送给企业微信的流式片段。"""
|
||||
|
||||
# 需要返回给企业微信的文本内容
|
||||
content: str
|
||||
|
||||
# 标记是否为最终片段,对应企业微信协议里的 finish 字段
|
||||
is_final: bool = False
|
||||
|
||||
# 预留额外元信息,未来支持多模态扩展时可使用
|
||||
meta: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamSession:
|
||||
"""维护一次企业微信流式会话的上下文。"""
|
||||
|
||||
# 企业微信要求的 stream_id,用于标识后续刷新请求
|
||||
stream_id: str
|
||||
|
||||
# 原始消息的 msgid,便于与流水线消息对应
|
||||
msg_id: str
|
||||
|
||||
# 群聊会话标识(单聊时为空)
|
||||
chat_id: Optional[str]
|
||||
|
||||
# 触发消息的发送者
|
||||
user_id: Optional[str]
|
||||
|
||||
# 会话创建时间
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
# 最近一次被访问的时间,cleanup 依据该值判断过期
|
||||
last_access: float = field(default_factory=time.time)
|
||||
|
||||
# 将流水线增量结果缓存到队列,刷新请求逐条消费
|
||||
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
|
||||
|
||||
# 是否已经完成(收到最终片段)
|
||||
finished: bool = False
|
||||
|
||||
# 缓存最近一次片段,处理重试或超时兜底
|
||||
last_chunk: Optional[StreamChunk] = None
|
||||
|
||||
|
||||
class StreamSessionManager:
|
||||
"""管理 stream 会话的生命周期,并负责队列的生产消费。"""
|
||||
|
||||
def __init__(self, logger: EventLogger, ttl: int = 60) -> None:
|
||||
self.logger = logger
|
||||
|
||||
self.ttl = ttl # 超时时间(秒),超过该时间未被访问的会话会被清理由 cleanup
|
||||
self._sessions: dict[str, StreamSession] = {} # stream_id -> StreamSession 映射
|
||||
self._msg_index: dict[str, str] = {} # msgid -> stream_id 映射,便于流水线根据消息 ID 找到会话
|
||||
|
||||
def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]:
|
||||
if not msg_id:
|
||||
return None
|
||||
return self._msg_index.get(msg_id)
|
||||
|
||||
def get_session(self, stream_id: str) -> Optional[StreamSession]:
|
||||
return self._sessions.get(stream_id)
|
||||
|
||||
def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]:
|
||||
"""根据企业微信回调创建或获取会话。
|
||||
|
||||
Args:
|
||||
msg_json: 企业微信解密后的回调 JSON。
|
||||
|
||||
Returns:
|
||||
Tuple[StreamSession, bool]: `StreamSession` 为会话实例,`bool` 指示是否为新建会话。
|
||||
|
||||
Example:
|
||||
在首次回调中调用,得到 `is_new=True` 后再触发流水线。
|
||||
"""
|
||||
msg_id = msg_json.get('msgid', '')
|
||||
if msg_id and msg_id in self._msg_index:
|
||||
stream_id = self._msg_index[msg_id]
|
||||
session = self._sessions.get(stream_id)
|
||||
if session:
|
||||
session.last_access = time.time()
|
||||
return session, False
|
||||
|
||||
stream_id = str(uuid.uuid4())
|
||||
session = StreamSession(
|
||||
stream_id=stream_id,
|
||||
msg_id=msg_id,
|
||||
chat_id=msg_json.get('chatid'),
|
||||
user_id=msg_json.get('from', {}).get('userid'),
|
||||
)
|
||||
|
||||
if msg_id:
|
||||
self._msg_index[msg_id] = stream_id
|
||||
self._sessions[stream_id] = session
|
||||
return session, True
|
||||
|
||||
async def publish(self, stream_id: str, chunk: StreamChunk) -> bool:
|
||||
"""向 stream 队列写入新的增量片段。
|
||||
|
||||
Args:
|
||||
stream_id: 企业微信分配的流式会话 ID。
|
||||
chunk: 待发送的增量片段。
|
||||
|
||||
Returns:
|
||||
bool: 当流式队列存在并成功入队时返回 True。
|
||||
|
||||
Example:
|
||||
在收到模型增量后调用 `await manager.publish('sid', StreamChunk('hello'))`。
|
||||
"""
|
||||
session = self._sessions.get(stream_id)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
session.last_access = time.time()
|
||||
session.last_chunk = chunk
|
||||
|
||||
try:
|
||||
session.queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
# 默认无界队列,此处兜底防御
|
||||
await session.queue.put(chunk)
|
||||
|
||||
if chunk.is_final:
|
||||
session.finished = True
|
||||
|
||||
return True
|
||||
|
||||
async def consume(self, stream_id: str, timeout: float = 0.5) -> Optional[StreamChunk]:
|
||||
"""从队列中取出一个片段,若超时返回 None。
|
||||
|
||||
Args:
|
||||
stream_id: 企业微信流式会话 ID。
|
||||
timeout: 取片段的最长等待时间(秒)。
|
||||
|
||||
Returns:
|
||||
Optional[StreamChunk]: 成功时返回片段,超时或会话不存在时返回 None。
|
||||
|
||||
Example:
|
||||
企业微信刷新到达时调用,若队列有数据则立即返回 `StreamChunk`。
|
||||
"""
|
||||
session = self._sessions.get(stream_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
session.last_access = time.time()
|
||||
|
||||
try:
|
||||
chunk = await asyncio.wait_for(session.queue.get(), timeout)
|
||||
session.last_access = time.time()
|
||||
if chunk.is_final:
|
||||
session.finished = True
|
||||
return chunk
|
||||
except asyncio.TimeoutError:
|
||||
if session.finished and session.last_chunk:
|
||||
return session.last_chunk
|
||||
return None
|
||||
|
||||
def mark_finished(self, stream_id: str) -> None:
|
||||
session = self._sessions.get(stream_id)
|
||||
if session:
|
||||
session.finished = True
|
||||
session.last_access = time.time()
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""定期清理过期会话,防止队列与映射无上限累积。"""
|
||||
now = time.time()
|
||||
expired: list[str] = []
|
||||
for stream_id, session in self._sessions.items():
|
||||
if now - session.last_access > self.ttl:
|
||||
expired.append(stream_id)
|
||||
|
||||
for stream_id in expired:
|
||||
session = self._sessions.pop(stream_id, None)
|
||||
if not session:
|
||||
continue
|
||||
msg_id = session.msg_id
|
||||
if msg_id and self._msg_index.get(msg_id) == stream_id:
|
||||
self._msg_index.pop(msg_id, None)
|
||||
|
||||
|
||||
class WecomBotClient:
|
||||
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger):
|
||||
"""企业微信智能机器人客户端。
|
||||
|
||||
Args:
|
||||
Token: 企业微信回调验证使用的 token。
|
||||
EnCodingAESKey: 企业微信消息加解密密钥。
|
||||
Corpid: 企业 ID。
|
||||
logger: 日志记录器。
|
||||
|
||||
Example:
|
||||
>>> client = WecomBotClient(Token='token', EnCodingAESKey='aeskey', Corpid='corp', logger=logger)
|
||||
"""
|
||||
|
||||
self.Token = Token
|
||||
self.EnCodingAESKey = EnCodingAESKey
|
||||
self.Corpid = Corpid
|
||||
self.ReceiveId = ''
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['POST', 'GET']
|
||||
)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
self.logger = logger
|
||||
self.generated_content: dict[str, str] = {}
|
||||
self.msg_id_map: dict[str, int] = {}
|
||||
self.stream_sessions = StreamSessionManager(logger=logger)
|
||||
self.stream_poll_timeout = 0.5
|
||||
|
||||
@staticmethod
|
||||
def _build_stream_payload(stream_id: str, content: str, finish: bool) -> dict[str, Any]:
|
||||
"""按照企业微信协议拼装返回报文。
|
||||
|
||||
Args:
|
||||
stream_id: 企业微信会话 ID。
|
||||
content: 推送的文本内容。
|
||||
finish: 是否为最终片段。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 可直接加密返回的 payload。
|
||||
|
||||
Example:
|
||||
组装 `{'msgtype': 'stream', 'stream': {'id': 'sid', ...}}` 结构。
|
||||
"""
|
||||
return {
|
||||
'msgtype': 'stream',
|
||||
'stream': {
|
||||
'id': stream_id,
|
||||
'finish': finish,
|
||||
'content': content,
|
||||
},
|
||||
}
|
||||
|
||||
async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
"""对响应进行加密封装并返回给企业微信。
|
||||
|
||||
Args:
|
||||
payload: 待加密的响应内容。
|
||||
nonce: 企业微信回调参数中的 nonce。
|
||||
|
||||
Returns:
|
||||
Tuple[Response, int]: Quart Response 对象及状态码。
|
||||
|
||||
Example:
|
||||
在首包或刷新场景中调用以生成加密响应。
|
||||
"""
|
||||
reply_plain_str = json.dumps(payload, ensure_ascii=False)
|
||||
reply_timestamp = str(int(time.time()))
|
||||
ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
|
||||
if ret != 0:
|
||||
await self.logger.error(f'加密失败: {ret}')
|
||||
return jsonify({'error': 'encrypt_failed'}), 500
|
||||
|
||||
root = ET.fromstring(encrypt_text)
|
||||
encrypt = root.find('Encrypt').text
|
||||
resp = {
|
||||
'encrypt': encrypt,
|
||||
}
|
||||
return jsonify(resp), 200
|
||||
|
||||
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent) -> None:
|
||||
"""异步触发流水线处理,避免阻塞首包响应。
|
||||
|
||||
Args:
|
||||
event: 由企业微信消息转换的内部事件对象。
|
||||
"""
|
||||
try:
|
||||
await self._handle_message(event)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
|
||||
async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
"""处理企业微信首次推送的消息,返回 stream_id 并开启流水线。
|
||||
|
||||
Args:
|
||||
msg_json: 解密后的企业微信消息 JSON。
|
||||
nonce: 企业微信回调参数 nonce。
|
||||
|
||||
Returns:
|
||||
Tuple[Response, int]: Quart Response 及状态码。
|
||||
|
||||
Example:
|
||||
首次回调时调用,立即返回带 `stream_id` 的响应。
|
||||
"""
|
||||
session, is_new = self.stream_sessions.create_or_get(msg_json)
|
||||
|
||||
message_data = await self.get_message(msg_json)
|
||||
if message_data:
|
||||
message_data['stream_id'] = session.stream_id
|
||||
try:
|
||||
event = wecombotevent.WecomBotEvent(message_data)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
else:
|
||||
if is_new:
|
||||
asyncio.create_task(self._dispatch_event(event))
|
||||
|
||||
payload = self._build_stream_payload(session.stream_id, '', False)
|
||||
return await self._encrypt_and_reply(payload, nonce)
|
||||
|
||||
async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
"""处理企业微信的流式刷新请求,按需返回增量片段。
|
||||
|
||||
Args:
|
||||
msg_json: 解密后的企业微信刷新请求。
|
||||
nonce: 企业微信回调参数 nonce。
|
||||
|
||||
Returns:
|
||||
Tuple[Response, int]: Quart Response 及状态码。
|
||||
|
||||
Example:
|
||||
在刷新请求中调用,按需返回增量片段。
|
||||
"""
|
||||
stream_info = msg_json.get('stream', {})
|
||||
stream_id = stream_info.get('id', '')
|
||||
if not stream_id:
|
||||
await self.logger.error('刷新请求缺少 stream.id')
|
||||
return await self._encrypt_and_reply(self._build_stream_payload('', '', True), nonce)
|
||||
|
||||
session = self.stream_sessions.get_session(stream_id)
|
||||
chunk = await self.stream_sessions.consume(stream_id, timeout=self.stream_poll_timeout)
|
||||
|
||||
if not chunk:
|
||||
cached_content = None
|
||||
if session and session.msg_id:
|
||||
cached_content = self.generated_content.pop(session.msg_id, None)
|
||||
if cached_content is not None:
|
||||
chunk = StreamChunk(content=cached_content, is_final=True)
|
||||
else:
|
||||
payload = self._build_stream_payload(stream_id, '', False)
|
||||
return await self._encrypt_and_reply(payload, nonce)
|
||||
|
||||
payload = self._build_stream_payload(stream_id, chunk.content, chunk.is_final)
|
||||
if chunk.is_final:
|
||||
self.stream_sessions.mark_finished(stream_id)
|
||||
return await self._encrypt_and_reply(payload, nonce)
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""企业微信回调入口。
|
||||
|
||||
Returns:
|
||||
Quart Response: 根据请求类型返回验证、首包或刷新结果。
|
||||
|
||||
Example:
|
||||
作为 Quart 路由处理函数直接注册并使用。
|
||||
"""
|
||||
try:
|
||||
self.wxcpt = WXBizMsgCrypt(self.Token, self.EnCodingAESKey, '')
|
||||
await self.logger.info(f'{request.method} {request.url} {str(request.args)}')
|
||||
|
||||
if request.method == 'GET':
|
||||
return await self._handle_get_callback()
|
||||
|
||||
if request.method == 'POST':
|
||||
return await self._handle_post_callback()
|
||||
|
||||
return Response('', status=405)
|
||||
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
return Response('Internal Server Error', status=500)
|
||||
|
||||
async def _handle_get_callback(self) -> tuple[Response, int] | Response:
|
||||
"""处理企业微信的 GET 验证请求。"""
|
||||
|
||||
msg_signature = unquote(request.args.get('msg_signature', ''))
|
||||
timestamp = unquote(request.args.get('timestamp', ''))
|
||||
nonce = unquote(request.args.get('nonce', ''))
|
||||
echostr = unquote(request.args.get('echostr', ''))
|
||||
|
||||
if not all([msg_signature, timestamp, nonce, echostr]):
|
||||
await self.logger.error('请求参数缺失')
|
||||
return Response('缺少参数', status=400)
|
||||
|
||||
ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
await self.logger.error('验证URL失败')
|
||||
return Response('验证失败', status=403)
|
||||
|
||||
return Response(decrypted_str, mimetype='text/plain')
|
||||
|
||||
async def _handle_post_callback(self) -> tuple[Response, int] | Response:
|
||||
"""处理企业微信的 POST 回调请求。"""
|
||||
|
||||
self.stream_sessions.cleanup()
|
||||
|
||||
msg_signature = unquote(request.args.get('msg_signature', ''))
|
||||
timestamp = unquote(request.args.get('timestamp', ''))
|
||||
nonce = unquote(request.args.get('nonce', ''))
|
||||
|
||||
encrypted_json = await request.get_json()
|
||||
encrypted_msg = (encrypted_json or {}).get('encrypt', '')
|
||||
if not encrypted_msg:
|
||||
await self.logger.error("请求体中缺少 'encrypt' 字段")
|
||||
return Response('Bad Request', status=400)
|
||||
|
||||
xml_post_data = f'<xml><Encrypt><![CDATA[{encrypted_msg}]]></Encrypt></xml>'
|
||||
ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
await self.logger.error('解密失败')
|
||||
return Response('解密失败', status=400)
|
||||
|
||||
msg_json = json.loads(decrypted_xml)
|
||||
|
||||
if msg_json.get('msgtype') == 'stream':
|
||||
return await self._handle_post_followup_response(msg_json, nonce)
|
||||
|
||||
return await self._handle_post_initial_response(msg_json, nonce)
|
||||
|
||||
async def get_message(self, msg_json):
|
||||
message_data = {}
|
||||
|
||||
if msg_json.get('chattype', '') == 'single':
|
||||
message_data['type'] = 'single'
|
||||
elif msg_json.get('chattype', '') == 'group':
|
||||
message_data['type'] = 'group'
|
||||
|
||||
if msg_json.get('msgtype') == 'text':
|
||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
||||
elif msg_json.get('msgtype') == 'image':
|
||||
picurl = msg_json.get('image', {}).get('url', '')
|
||||
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
|
||||
message_data['picurl'] = base64
|
||||
elif msg_json.get('msgtype') == 'mixed':
|
||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
picurl = None
|
||||
for item in items:
|
||||
if item.get('msgtype') == 'text':
|
||||
texts.append(item.get('text', {}).get('content', ''))
|
||||
elif item.get('msgtype') == 'image' and picurl is None:
|
||||
picurl = item.get('image', {}).get('url')
|
||||
|
||||
if texts:
|
||||
message_data['content'] = ''.join(texts) # 拼接所有 text
|
||||
if picurl:
|
||||
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
|
||||
message_data['picurl'] = base64 # 只保留第一个 image
|
||||
|
||||
# Extract user information
|
||||
from_info = msg_json.get('from', {})
|
||||
message_data['userid'] = from_info.get('userid', '')
|
||||
message_data['username'] = (
|
||||
from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
||||
)
|
||||
|
||||
# Extract chat/group information
|
||||
if msg_json.get('chattype', '') == 'group':
|
||||
message_data['chatid'] = msg_json.get('chatid', '')
|
||||
# Try to get group name if available
|
||||
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
||||
|
||||
message_data['msgid'] = msg_json.get('msgid', '')
|
||||
|
||||
if msg_json.get('aibotid'):
|
||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
||||
|
||||
return message_data
|
||||
|
||||
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
|
||||
"""
|
||||
处理消息事件。
|
||||
"""
|
||||
try:
|
||||
message_id = event.message_id
|
||||
if message_id in self.msg_id_map.keys():
|
||||
self.msg_id_map[message_id] += 1
|
||||
return
|
||||
self.msg_id_map[message_id] = 1
|
||||
msg_type = event.type
|
||||
if msg_type in self._message_handlers:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
|
||||
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
|
||||
"""将流水线片段推送到 stream 会话。
|
||||
|
||||
Args:
|
||||
msg_id: 原始企业微信消息 ID。
|
||||
content: 模型产生的片段内容。
|
||||
is_final: 是否为最终片段。
|
||||
|
||||
Returns:
|
||||
bool: 当成功写入流式队列时返回 True。
|
||||
|
||||
Example:
|
||||
在流水线 `reply_message_chunk` 中调用,将增量推送至企业微信。
|
||||
"""
|
||||
# 根据 msg_id 找到对应 stream 会话,如果不存在说明当前消息非流式
|
||||
stream_id = self.stream_sessions.get_stream_id_by_msg(msg_id)
|
||||
if not stream_id:
|
||||
return False
|
||||
|
||||
chunk = StreamChunk(content=content, is_final=is_final)
|
||||
await self.stream_sessions.publish(stream_id, chunk)
|
||||
if is_final:
|
||||
self.stream_sessions.mark_finished(stream_id)
|
||||
return True
|
||||
|
||||
async def set_message(self, msg_id: str, content: str):
|
||||
"""兼容旧逻辑:若无法流式返回则缓存最终结果。
|
||||
|
||||
Args:
|
||||
msg_id: 企业微信消息 ID。
|
||||
content: 最终回复的文本内容。
|
||||
|
||||
Example:
|
||||
在非流式场景下缓存最终结果以备刷新时返回。
|
||||
"""
|
||||
handled = await self.push_stream_chunk(msg_id, content, is_final=True)
|
||||
if not handled:
|
||||
self.generated_content[msg_id] = content
|
||||
|
||||
def on_message(self, msg_type: str):
|
||||
def decorator(func: Callable[[wecombotevent.WecomBotEvent], None]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def download_url_to_base64(self, download_url, encoding_aes_key):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url)
|
||||
if response.status_code != 200:
|
||||
await self.logger.error(f'failed to get file: {response.text}')
|
||||
return None
|
||||
|
||||
encrypted_bytes = response.content
|
||||
|
||||
aes_key = base64.b64decode(encoding_aes_key + '=') # base64 补齐
|
||||
iv = aes_key[:16]
|
||||
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted = cipher.decrypt(encrypted_bytes)
|
||||
|
||||
pad_len = decrypted[-1]
|
||||
decrypted = decrypted[:-pad_len]
|
||||
|
||||
if decrypted.startswith(b'\xff\xd8'): # JPEG
|
||||
mime_type = 'image/jpeg'
|
||||
elif decrypted.startswith(b'\x89PNG'): # PNG
|
||||
mime_type = 'image/png'
|
||||
elif decrypted.startswith((b'GIF87a', b'GIF89a')): # GIF
|
||||
mime_type = 'image/gif'
|
||||
elif decrypted.startswith(b'BM'): # BMP
|
||||
mime_type = 'image/bmp'
|
||||
elif decrypted.startswith(b'II*\x00') or decrypted.startswith(b'MM\x00*'): # TIFF
|
||||
mime_type = 'image/tiff'
|
||||
else:
|
||||
mime_type = 'application/octet-stream'
|
||||
|
||||
# 转 base64
|
||||
base64_str = base64.b64encode(decrypted).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
启动 Quart 应用。
|
||||
"""
|
||||
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#########################################################################
|
||||
# Author: jonyqin
|
||||
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
|
||||
# File Name: ierror.py
|
||||
# Description:定义错误码含义
|
||||
#########################################################################
|
||||
WXBizMsgCrypt_OK = 0
|
||||
WXBizMsgCrypt_ValidateSignature_Error = -40001
|
||||
WXBizMsgCrypt_ParseXml_Error = -40002
|
||||
WXBizMsgCrypt_ComputeSignature_Error = -40003
|
||||
WXBizMsgCrypt_IllegalAesKey = -40004
|
||||
WXBizMsgCrypt_ValidateCorpid_Error = -40005
|
||||
WXBizMsgCrypt_EncryptAES_Error = -40006
|
||||
WXBizMsgCrypt_DecryptAES_Error = -40007
|
||||
WXBizMsgCrypt_IllegalBuffer = -40008
|
||||
WXBizMsgCrypt_EncodeBase64_Error = -40009
|
||||
WXBizMsgCrypt_DecodeBase64_Error = -40010
|
||||
WXBizMsgCrypt_GenReturnXml_Error = -40011
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class WecomBotEvent(dict):
|
||||
@staticmethod
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional['WecomBotEvent']:
|
||||
try:
|
||||
event = WecomBotEvent(payload)
|
||||
return event
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""
|
||||
事件类型
|
||||
"""
|
||||
return self.get('type', '')
|
||||
|
||||
@property
|
||||
def userid(self) -> str:
|
||||
"""
|
||||
用户id
|
||||
"""
|
||||
return self.get('from', {}).get('userid', '') or self.get('userid', '')
|
||||
|
||||
@property
|
||||
def username(self) -> str:
|
||||
"""
|
||||
用户名称
|
||||
"""
|
||||
return (
|
||||
self.get('username', '')
|
||||
or self.get('from', {}).get('alias', '')
|
||||
or self.get('from', {}).get('name', '')
|
||||
or self.userid
|
||||
)
|
||||
|
||||
@property
|
||||
def chatname(self) -> str:
|
||||
"""
|
||||
群组名称
|
||||
"""
|
||||
return self.get('chatname', '') or str(self.chatid)
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
"""
|
||||
内容
|
||||
"""
|
||||
return self.get('content', '')
|
||||
|
||||
@property
|
||||
def picurl(self) -> str:
|
||||
"""
|
||||
图片url
|
||||
"""
|
||||
return self.get('picurl', '')
|
||||
|
||||
@property
|
||||
def chatid(self) -> str:
|
||||
"""
|
||||
群组id
|
||||
"""
|
||||
return self.get('chatid', {})
|
||||
|
||||
@property
|
||||
def message_id(self) -> str:
|
||||
"""
|
||||
消息id
|
||||
"""
|
||||
return self.get('msgid', '')
|
||||
|
||||
@property
|
||||
def ai_bot_id(self) -> str:
|
||||
"""
|
||||
AI Bot ID
|
||||
"""
|
||||
return self.get('aibotid', '')
|
||||
Reference in New Issue
Block a user