style: restrict line-length

This commit is contained in:
Junyan Qin
2025-05-10 18:04:58 +08:00
parent b30016ed08
commit 055b389353
134 changed files with 1096 additions and 2595 deletions

View File

@@ -65,9 +65,7 @@ class RouterGroup(abc.ABC):
async def handler_error(*args, **kwargs):
if auth_type == AuthType.USER_TOKEN:
# 从Authorization头中获取token
token = quart.request.headers.get('Authorization', '').replace(
'Bearer ', ''
)
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
if not token:
return self.http_status(401, -1, '未提供有效的用户令牌')

View File

@@ -14,10 +14,8 @@ class LogsRouterGroup(group.RouterGroup):
start_page_number = int(quart.request.args.get('start_page_number', 0))
start_offset = int(quart.request.args.get('start_offset', 0))
logs_str, end_page_number, end_offset = (
self.ap.log_cache.get_log_by_pointer(
start_page_number=start_page_number, start_offset=start_offset
)
logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer(
start_page_number=start_page_number, start_offset=start_offset
)
return self.success(

View File

@@ -11,23 +11,17 @@ class PipelinesRouterGroup(group.RouterGroup):
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(
data={'pipelines': await self.ap.pipeline_service.get_pipelines()}
)
return self.success(data={'pipelines': await self.ap.pipeline_service.get_pipelines()})
elif quart.request.method == 'POST':
json_data = await quart.request.json
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(
json_data
)
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data)
return self.success(data={'uuid': pipeline_uuid})
@self.route('/_/metadata', methods=['GET'])
async def _() -> str:
return self.success(
data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()}
)
return self.success(data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()})
@self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(pipeline_uuid: str) -> str:

View File

@@ -8,30 +8,20 @@ class AdaptersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
return self.success(
data={'adapters': self.ap.platform_mgr.get_available_adapters_info()}
)
return self.success(data={'adapters': self.ap.platform_mgr.get_available_adapters_info()})
@self.route('/<adapter_name>', methods=['GET'])
async def _(adapter_name: str) -> str:
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(
adapter_name
)
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name)
if adapter_info is None:
return self.http_status(404, -1, 'adapter not found')
return self.success(data={'adapter': adapter_info})
@self.route(
'/<adapter_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE
)
@self.route('/<adapter_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE)
async def _(adapter_name: str) -> quart.Response:
adapter_manifest = (
self.ap.platform_mgr.get_available_adapter_manifest_by_name(
adapter_name
)
)
adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name)
if adapter_manifest is None:
return self.http_status(404, -1, 'adapter not found')

View File

@@ -92,9 +92,7 @@ class PluginsRouterGroup(group.RouterGroup):
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
return self.success()
@self.route(
'/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
data = await quart.request.json

View File

@@ -9,9 +9,7 @@ class LLMModelsRouterGroup(group.RouterGroup):
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(
data={'models': await self.ap.model_service.get_llm_models()}
)
return self.success(data={'models': await self.ap.model_service.get_llm_models()})
elif quart.request.method == 'POST':
json_data = await quart.request.json

View File

@@ -8,30 +8,20 @@ class RequestersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> quart.Response:
return self.success(
data={'requesters': self.ap.model_mgr.get_available_requesters_info()}
)
return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()})
@self.route('/<requester_name>', methods=['GET'])
async def _(requester_name: str) -> quart.Response:
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(
requester_name
)
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name)
if requester_info is None:
return self.http_status(404, -1, 'requester not found')
return self.success(data={'requester': requester_info})
@self.route(
'/<requester_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE
)
@self.route('/<requester_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE)
async def _(requester_name: str) -> quart.Response:
requester_manifest = (
self.ap.model_mgr.get_available_requester_manifest_by_name(
requester_name
)
)
requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name)
if requester_manifest is None:
return self.http_status(404, -1, 'requester not found')

View File

@@ -8,9 +8,7 @@ class StatsRouterGroup(group.RouterGroup):
async def _() -> str:
conv_count = 0
for session in self.ap.sess_mgr.session_list:
conv_count += len(
session.conversations if session.conversations is not None else []
)
conv_count += len(session.conversations if session.conversations is not None else [])
return self.success(
data={

View File

@@ -13,9 +13,7 @@ class SystemRouterGroup(group.RouterGroup):
data={
'version': constants.semantic_version,
'debug': constants.debug_mode,
'enabled_platform_count': len(
self.ap.platform_mgr.get_running_adapters()
),
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
}
)
@@ -28,9 +26,7 @@ class SystemRouterGroup(group.RouterGroup):
return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type))
@self.route(
'/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
@self.route('/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _(task_id: str) -> str:
task = self.ap.task_mgr.get_task_by_id(int(task_id))
@@ -48,9 +44,7 @@ class SystemRouterGroup(group.RouterGroup):
await self.ap.reload(scope=scope)
return self.success()
@self.route(
'/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')

View File

@@ -10,9 +10,7 @@ class UserRouterGroup(group.RouterGroup):
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
if quart.request.method == 'GET':
return self.success(
data={'initialized': await self.ap.user_service.is_initialized()}
)
return self.success(data={'initialized': await self.ap.user_service.is_initialized()})
if await self.ap.user_service.is_initialized():
return self.fail(1, '系统已初始化')
@@ -31,17 +29,13 @@ class UserRouterGroup(group.RouterGroup):
json_data = await quart.request.json
try:
token = await self.ap.user_service.authenticate(
json_data['user'], json_data['password']
)
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
except argon2.exceptions.VerifyMismatchError:
return self.fail(1, '用户名或密码错误')
return self.success(data={'token': token})
@self.route(
'/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
@self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
token = await self.ap.user_service.generate_jwt_token(user_email)

View File

@@ -70,15 +70,12 @@ class HTTPController:
@self.quart_app.route('/')
async def index():
return await quart.send_from_directory(
frontend_path, 'index.html', mimetype='text/html'
)
return await quart.send_from_directory(frontend_path, 'index.html', mimetype='text/html')
@self.quart_app.route('/<path:path>')
async def static_file(path: str):
if not (
os.path.exists(os.path.join(frontend_path, path))
and os.path.isfile(os.path.join(frontend_path, path))
os.path.exists(os.path.join(frontend_path, path)) and os.path.isfile(os.path.join(frontend_path, path))
):
if os.path.exists(os.path.join(frontend_path, path + '.html')):
path += '.html'
@@ -110,6 +107,4 @@ class HTTPController:
elif path.endswith('.txt'):
mimetype = 'text/plain'
return await quart.send_from_directory(
frontend_path, path, mimetype=mimetype
)
return await quart.send_from_directory(frontend_path, path, mimetype=mimetype)

View File

@@ -18,23 +18,16 @@ class BotService:
async def get_bots(self) -> list[dict]:
"""获取所有机器人"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
bots = result.all()
return [
self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
for bot in bots
]
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots]
async def get_bot(self, bot_uuid: str) -> dict | None:
"""获取机器人"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
)
bot = result.first()
@@ -60,9 +53,7 @@ class BotService:
bot_data['use_pipeline_uuid'] = pipeline.uuid
bot_data['use_pipeline_name'] = pipeline.name
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_bot.Bot).values(bot_data)
)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_bot.Bot).values(bot_data))
bot = await self.get_bot(bot_data['uuid'])
@@ -79,8 +70,7 @@ class BotService:
if 'use_pipeline_uuid' in bot_data:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid
== bot_data['use_pipeline_uuid']
persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']
)
)
pipeline = result.first()
@@ -90,9 +80,7 @@ class BotService:
raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot)
.values(bot_data)
.where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
)
await self.ap.platform_mgr.remove_bot(bot_uuid)
@@ -108,7 +96,5 @@ class BotService:
"""删除机器人"""
await self.ap.platform_mgr.remove_bot(bot_uuid)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
)

View File

@@ -15,22 +15,15 @@ class ModelsService:
self.ap = ap
async def get_llm_models(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
models = result.all()
return [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
for model in models
]
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
async def create_llm_model(self, model_data: dict) -> str:
model_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data))
llm_model = await self.get_llm_model(model_data['uuid'])
@@ -53,9 +46,7 @@ class ModelsService:
async def get_llm_model(self, model_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
model = result.first()
@@ -63,9 +54,7 @@ class ModelsService:
if model is None:
return None
return self.ap.persistence_mgr.serialize_model(
persistence_model.LLMModel, model
)
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
if 'uuid' in model_data:
@@ -85,9 +74,7 @@ class ModelsService:
async def delete_llm_model(self, model_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
await self.ap.model_mgr.remove_llm_model(model_uuid)

View File

@@ -39,15 +39,11 @@ class PipelineService:
]
async def get_pipelines(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
pipelines = result.all()
return [
self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
for pipeline in pipelines
]
@@ -63,23 +59,17 @@ class PipelineService:
if pipeline is None:
return None
return self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
pipeline_data['uuid'] = str(uuid.uuid4())
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
pipeline_data['stages'] = default_stage_order.copy()
pipeline_data['is_default'] = default
pipeline_data['config'] = json.load(
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
)
pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(
**pipeline_data
)
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data)
)
pipeline = await self.get_pipeline(pipeline_data['uuid'])

View File

@@ -17,9 +17,7 @@ class UserService:
self.ap = ap
async def is_initialized(self) -> bool:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(user.User).limit(1)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(user.User).limit(1))
result_list = result.all()
return result_list is not None and len(result_list) > 0
@@ -30,9 +28,7 @@ class UserService:
hashed_password = ph.hash(password)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(user.User).values(
user=user_email, password=hashed_password
)
sqlalchemy.insert(user.User).values(user=user_email, password=hashed_password)
)
async def get_user_by_email(self, user_email: str) -> user.User | None:
@@ -41,9 +37,7 @@ class UserService:
)
result_list = result.all()
return (
result_list[0] if result_list is not None and len(result_list) > 0 else None
)
return result_list[0] if result_list is not None and len(result_list) > 0 else None
async def authenticate(self, user_email: str, password: str) -> str | None:
result = await self.ap.persistence_mgr.execute_async(

View File

@@ -40,18 +40,14 @@ class CommandManager:
# 应用命令权限配置
for cls in operator.preregistered_operators:
if cls.path in self.ap.instance_config.data['command']['privilege']:
cls.lowest_privilege = self.ap.instance_config.data['command'][
'privilege'
][cls.path]
cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path]
# 实例化所有类
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
# 设置所有类的子节点
for cmd in self.cmd_list:
cmd.children = [
child for child in self.cmd_list if child.parent_class == cmd.__class__
]
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
# 初始化所有类
for cmd in self.cmd_list:
@@ -68,10 +64,7 @@ class CommandManager:
found = False
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (
context.crt_params[0] == oper.name
or context.crt_params[0] in oper.alias
) and (
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
oper.parent_class is None or oper.parent_class == operator.__class__
):
found = True
@@ -85,14 +78,10 @@ class CommandManager:
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
if operator is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_params[0])
)
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(
error=errors.CommandPrivilegeError(operator.name)
)
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
else:
async for ret in operator.execute(context):
yield ret
@@ -107,10 +96,7 @@ class CommandManager:
privilege = 1
if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2
ctx = entities.ExecuteContext(

View File

@@ -95,9 +95,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令
支持多次yield以返回多个结果。

View File

@@ -9,9 +9,7 @@ from .. import operator, entities, errors
class CmdOperator(operator.CommandOperator):
"""命令列表"""
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
if len(context.crt_params) == 0:
reply_str = '当前所有命令: \n\n'
@@ -30,16 +28,12 @@ class CmdOperator(operator.CommandOperator):
cmd = None
for _cmd in self.ap.cmd_mgr.cmd_list:
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (
_cmd.parent_class is None
):
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
cmd = _cmd
break
if cmd is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(cmd_name)
)
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
else:
reply_str = f'{cmd.name}: {cmd.help}\n\n'
reply_str += f'使用方法: \n{cmd.usage}'

View File

@@ -5,55 +5,38 @@ import typing
from .. import operator, entities, errors
@operator.operator_class(
name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all'
)
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
class DelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
delete_index = 0
if len(context.crt_params) > 0:
try:
delete_index = int(context.crt_params[0])
except Exception:
yield entities.CommandReturn(
error=errors.CommandOperationError('索引必须是整数')
)
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
return
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(
error=errors.CommandOperationError('索引超出范围')
)
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
return
# 倒序
to_delete_index = len(context.session.conversations) - 1 - delete_index
if (
context.session.conversations[to_delete_index]
== context.session.using_conversation
):
if context.session.conversations[to_delete_index] == context.session.using_conversation:
context.session.using_conversation = None
del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
else:
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
@operator.operator_class(
name='all', help='删除此会话的所有历史记录', parent_class=DelOperator
)
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
class DelAllOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = []
context.session.using_conversation = None

View File

@@ -6,9 +6,7 @@ from .. import operator, entities
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
class FuncOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = '当前已启用的内容函数: \n\n'
index = 1

View File

@@ -7,9 +7,7 @@ from .. import operator, entities
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
class HelpOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app'
help += '\n发送命令 !cmd 可查看命令列表'

View File

@@ -8,36 +8,21 @@ from .. import operator, entities, errors
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
class LastOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的上一个会话
for index in range(len(context.session.conversations) - 1, -1, -1):
if (
context.session.conversations[index]
== context.session.using_conversation
):
if context.session.conversations[index] == context.session.using_conversation:
if index == 0:
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是第一个对话了')
)
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
return
else:
context.session.using_conversation = (
context.session.conversations[index - 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
context.session.using_conversation = context.session.conversations[index - 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield entities.CommandReturn(
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
)
return
else:
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -5,22 +5,16 @@ import typing
from .. import operator, entities, errors
@operator.operator_class(
name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>'
)
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
class ListOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0
if len(context.crt_params) > 0:
try:
page = int(context.crt_params[0] - 1)
except Exception:
yield entities.CommandReturn(
error=errors.CommandOperationError('页码应为整数')
)
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
return
record_per_page = 10
@@ -38,7 +32,9 @@ class ListOperator(operator.CommandOperator):
using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
content += (
f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
)
index += 1
if content == '':

View File

@@ -14,9 +14,7 @@ from .. import operator, entities, errors
class ModelOperator(operator.CommandOperator):
"""Model命令"""
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content = '模型列表:\n'
model_list = self.ap.model_mgr.model_list
@@ -31,15 +29,11 @@ class ModelOperator(operator.CommandOperator):
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator
)
@operator.operator_class(name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator)
class ModelShowOperator(operator.CommandOperator):
"""Model Show命令"""
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
@@ -49,9 +43,7 @@ class ModelShowOperator(operator.CommandOperator):
break
if model is None:
yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
yield entities.CommandReturn(error=errors.CommandError(f'未找到模型 {model_name}'))
else:
content = '模型详情\n'
content += f'名称: {model.name}\n'
@@ -65,15 +57,11 @@ class ModelShowOperator(operator.CommandOperator):
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator
)
@operator.operator_class(name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator)
class ModelSetOperator(operator.CommandOperator):
"""Model Set命令"""
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
@@ -83,12 +71,8 @@ class ModelSetOperator(operator.CommandOperator):
break
if model is None:
yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
yield entities.CommandReturn(error=errors.CommandError(f'未找到模型 {model_name}'))
else:
self.ap.provider_cfg.data['model'] = model_name
await self.ap.provider_cfg.dump_config()
yield entities.CommandReturn(
text=f'已设置当前使用模型为 {model_name},重置会话以生效'
)
yield entities.CommandReturn(text=f'已设置当前使用模型为 {model_name},重置会话以生效')

View File

@@ -7,36 +7,21 @@ from .. import operator, entities, errors
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
class NextOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的下一个会话
for index in range(len(context.session.conversations)):
if (
context.session.conversations[index]
== context.session.using_conversation
):
if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations) - 1:
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是最后一个对话了')
)
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
return
else:
context.session.using_conversation = (
context.session.conversations[index + 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
context.session.using_conversation = context.session.conversations[index + 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield entities.CommandReturn(
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
)
return
else:
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -13,9 +13,7 @@ from .. import operator, entities, errors
usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>',
)
class OllamaOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
content: str = '模型列表:\n'
model_list: list = ollama.list().get('models', [])
@@ -25,9 +23,7 @@ class OllamaOperator(operator.CommandOperator):
content += f'大小: {bytes_to_mb(model["size"])}MB\n\n'
yield entities.CommandReturn(text=f'{content.strip()}')
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
yield entities.CommandReturn(error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常'))
def bytes_to_mb(num_bytes):
@@ -35,13 +31,9 @@ def bytes_to_mb(num_bytes):
return format(mb, '.2f')
@operator.operator_class(
name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator
)
@operator.operator_class(name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator)
class OllamaShowOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content: str = '模型详情:\n'
try:
show: dict = ollama.show(model=context.crt_params[0])
@@ -60,27 +52,19 @@ class OllamaShowOperator(operator.CommandOperator):
content += json.dumps(show, indent=4)
yield entities.CommandReturn(text=content.strip())
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常')
)
yield entities.CommandReturn(error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常'))
@operator.operator_class(
name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator
)
@operator.operator_class(name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator)
class OllamaPullOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
model_list: list = ollama.list().get('models', [])
if context.crt_params[0] in [model['name'] for model in model_list]:
yield entities.CommandReturn(text='模型已存在')
return
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
yield entities.CommandReturn(error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常'))
return
on_progress: bool = False
@@ -108,13 +92,9 @@ class OllamaPullOperator(operator.CommandOperator):
yield entities.CommandReturn(text=f'拉取失败: {e.error}')
@operator.operator_class(
name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator
)
@operator.operator_class(name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator)
class OllamaDelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
ret: str = ollama.delete(model=context.crt_params[0])['status']
except ollama.ResponseError as e:

View File

@@ -11,9 +11,7 @@ from .. import operator, entities, errors
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
)
class PluginOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins()
reply_str = '所有插件({}):\n'.format(len(plugin_list))
idx = 0
@@ -32,17 +30,11 @@ class PluginOperator(operator.CommandOperator):
yield entities.CommandReturn(text=reply_str)
@operator.operator_class(
name='get', help='安装插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
class PluginGetOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件仓库地址')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
else:
repo = context.crt_params[0]
@@ -53,22 +45,14 @@ class PluginGetOperator(operator.CommandOperator):
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件安装失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e)))
@operator.operator_class(
name='update', help='更新插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
class PluginUpdateOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
@@ -78,27 +62,17 @@ class PluginUpdateOperator(operator.CommandOperator):
if plugin_container is not None:
yield entities.CommandReturn(text='正在更新插件...')
await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(
text='插件更新成功,请重启程序以加载插件'
)
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件')
else:
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: 未找到插件')
)
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件'))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(
name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator
)
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
@@ -111,32 +85,20 @@ class PluginUpdateAllOperator(operator.CommandOperator):
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(
text='已更新插件: {}'.format(', '.join(updated))
)
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
else:
yield entities.CommandReturn(text='没有可更新的插件')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(
name='del', help='删除插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
class PluginDelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
@@ -146,79 +108,49 @@ class PluginDelOperator(operator.CommandOperator):
if plugin_container is not None:
yield entities.CommandReturn(text='正在删除插件...')
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(
text='插件删除成功,请重启程序以加载插件'
)
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件')
else:
yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: 未找到插件')
)
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件'))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e)))
@operator.operator_class(
name='on', help='启用插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
class PluginEnableOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(
text='已启用插件: {}'.format(plugin_name)
)
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name))
else:
yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
@operator.operator_class(
name='off', help='禁用插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
class PluginDisableOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(
text='已禁用插件: {}'.format(plugin_name)
)
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
else:
yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))

View File

@@ -7,14 +7,10 @@ from .. import operator, entities, errors
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
class PromptOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
if context.session.using_conversation is None:
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
else:
reply_str = '当前对话所有内容:\n\n'

View File

@@ -5,13 +5,9 @@ import typing
from .. import operator, entities, errors
@operator.operator_class(
name='resend', help='重发当前会话的最后一条消息', usage='!resend'
)
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
class ResendOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))

View File

@@ -7,9 +7,7 @@ from .. import operator, entities
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
class ResetOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
context.session.using_conversation = None

View File

@@ -8,9 +8,7 @@ from .. import operator, entities, errors
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
class UpdateCommand(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
yield entities.CommandReturn(text='正在进行更新...')
if await self.ap.ver_mgr.update_all():
@@ -19,6 +17,4 @@ class UpdateCommand(operator.CommandOperator):
yield entities.CommandReturn(text='当前已是最新版本')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('更新失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('更新失败: ' + str(e)))

View File

@@ -7,9 +7,7 @@ from .. import operator, entities
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
class VersionCommand(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
try:

View File

@@ -41,9 +41,7 @@ class ConfigManager:
self.file.save_sync(self.data)
async def load_python_module_config(
config_name: str, template_name: str, completion: bool = True
) -> ConfigManager:
async def load_python_module_config(config_name: str, template_name: str, completion: bool = True) -> ConfigManager:
"""加载Python模块配置文件
Args:

View File

@@ -160,9 +160,7 @@ class Application:
"""打印访问 webui 的提示"""
if not os.path.exists(os.path.join('.', 'web/out')):
self.logger.warning(
'WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html'
)
self.logger.warning('WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html')
return
host_ip = '127.0.0.1'

View File

@@ -26,9 +26,7 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
if constants.debug_mode:
level = logging.DEBUG
log_file_name = 'data/logs/langbot-%s.log' % time.strftime(
'%Y-%m-%d', time.localtime()
)
log_file_name = 'data/logs/langbot-%s.log' % time.strftime('%Y-%m-%d', time.localtime())
qcg_logger = logging.getLogger('langbot')
@@ -43,9 +41,7 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
stream_handler = logging.StreamHandler(sys.stdout)
# stream_handler.setLevel(level)
# stream_handler.setFormatter(color_formatter)
stream_handler.stream = open(
sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1
)
stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
log_handlers: list[logging.Handler] = [
stream_handler,

View File

@@ -87,8 +87,7 @@ class Query(pydantic.BaseModel):
"""使用的函数,由前置处理器阶段设置"""
resp_messages: (
typing.Optional[list[llm_entities.Message]]
| typing.Optional[list[platform_message.MessageChain]]
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]]
) = []
"""由Process阶段生成的回复消息对象列表"""
@@ -130,13 +129,9 @@ class Conversation(pydantic.BaseModel):
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_llm_model: requester.RuntimeLLMModel
@@ -162,17 +157,11 @@ class Session(pydantic.BaseModel):
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(
default_factory=list
)
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""

View File

@@ -11,16 +11,14 @@ class SensitiveWordMigration(migration.Migration):
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return os.path.exists(
'data/config/sensitive-words.json'
) and not os.path.exists('data/metadata/sensitive-words.json')
return os.path.exists('data/config/sensitive-words.json') and not os.path.exists(
'data/metadata/sensitive-words.json'
)
async def run(self):
"""执行迁移"""
# 移动文件
os.rename(
'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json'
)
os.rename('data/config/sensitive-words.json', 'data/metadata/sensitive-words.json')
# 重新加载配置
await self.ap.sensitive_meta.load_config()

View File

@@ -23,9 +23,7 @@ class OpenAIConfigMigration(migration.Migration):
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
self.ap.provider_cfg.data['model'] = old_openai_config[
'chat-completions-params'
]['model']
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
del old_openai_config['chat-completions-params']['model']

View File

@@ -15,8 +15,6 @@ class QCGCenterURLConfigMigration(migration.Migration):
"""执行迁移"""
if 'qcg-center-url' not in self.ap.system_cfg.data:
self.ap.system_cfg.data['qcg-center-url'] = (
'https://api.qchatgpt.rockchin.top/api/v2'
)
self.ap.system_cfg.data['qcg-center-url'] = 'https://api.qchatgpt.rockchin.top/api/v2'
await self.ap.system_cfg.dump_config()

View File

@@ -9,9 +9,7 @@ class AdFixwinConfigMigration(migration.Migration):
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return isinstance(
self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int
)
return isinstance(self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int)
async def run(self):
"""执行迁移"""
@@ -19,9 +17,7 @@ class AdFixwinConfigMigration(migration.Migration):
for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
temp_dict = {
'window-size': 60,
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][
session_name
],
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name],
}
self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict

View File

@@ -9,10 +9,7 @@ class HttpApiConfigMigration(migration.Migration):
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return (
'http-api' not in self.ap.system_cfg.data
or 'persistence' not in self.ap.system_cfg.data
)
return 'http-api' not in self.ap.system_cfg.data or 'persistence' not in self.ap.system_cfg.data
async def run(self):
"""执行迁移"""

View File

@@ -11,8 +11,7 @@ class DifyAPITimeoutParamsMigration(migration.Migration):
"""判断当前环境是否需要运行此迁移"""
return (
'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat']
or 'timeout'
not in self.ap.provider_cfg.data['dify-service-api']['workflow']
or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow']
or 'agent' not in self.ap.provider_cfg.data['dify-service-api']
)

View File

@@ -10,9 +10,7 @@ class SiliconFlowConfigMigration(migration.Migration):
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return (
'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
)
return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
async def run(self):
"""执行迁移"""

View File

@@ -13,17 +13,12 @@ class DifyThinkingConfigMigration(migration.Migration):
if 'options' not in self.ap.provider_cfg.data['dify-service-api']:
return True
if (
'convert-thinking-tips'
not in self.ap.provider_cfg.data['dify-service-api']['options']
):
if 'convert-thinking-tips' not in self.ap.provider_cfg.data['dify-service-api']['options']:
return True
return False
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dify-service-api']['options'] = {
'convert-thinking-tips': 'plain'
}
self.ap.provider_cfg.data['dify-service-api']['options'] = {'convert-thinking-tips': 'plain'}
await self.ap.provider_cfg.dump_config()

View File

@@ -24,8 +24,6 @@ class GewechatFileUrlConfigMigration(migration.Migration):
if adapter['adapter'] == 'gewechat':
if 'gewechat_file_url' not in adapter:
parsed_url = urlparse(adapter['gewechat_url'])
adapter['gewechat_file_url'] = (
f'{parsed_url.scheme}://{parsed_url.hostname}:2532'
)
adapter['gewechat_file_url'] = f'{parsed_url.scheme}://{parsed_url.hostname}:2532'
await self.ap.platform_cfg.dump_config()

View File

@@ -3,24 +3,23 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("tg-dingtalk-markdown", 38)
@migration.migration_class('tg-dingtalk-markdown', 38)
class TgDingtalkMarkdownMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
for adapter in self.ap.platform_cfg.data['platform-adapters']:
if adapter['adapter'] in ['dingtalk','telegram']:
if adapter['adapter'] in ['dingtalk', 'telegram']:
if 'markdown_card' not in adapter:
return True
return False
async def run(self):
"""执行迁移"""
for adapter in self.ap.platform_cfg.data['platform-adapters']:
if adapter['adapter'] in ['dingtalk','telegram']:
if adapter['adapter'] in ['dingtalk', 'telegram']:
if 'markdown_card' not in adapter:
adapter['markdown_card'] = False
await self.ap.platform_cfg.dump_config()

View File

@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("modelscope-config-completion", 39)
@migration.migration_class('modelscope-config-completion', 39)
class ModelScopeConfigCompletionMigration(migration.Migration):
"""ModelScope配置迁移
"""
"""ModelScope配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'modelscope' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['modelscope-chat-completions'] = {
'base-url': 'https://api-inference.modelscope.cn/v1',

View File

@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("ppio-config", 40)
@migration.migration_class('ppio-config', 40)
class PPIOConfigMigration(migration.Migration):
"""PPIO配置迁移
"""
"""PPIO配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'ppio-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'ppio' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['ppio-chat-completions'] = {
'base-url': 'https://api.ppinfra.com/v3/openai',

View File

@@ -35,9 +35,7 @@ class TaskContext:
if action is not None:
self.set_current_action(action)
self._log(
f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}'
)
self._log(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}')
def to_dict(self) -> dict:
return {'current_action': self.current_action, 'log': self.log}
@@ -104,9 +102,7 @@ class TaskWrapper:
name: str = '',
label: str = '',
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [
core_entities.LifecycleControlScope.APPLICATION
],
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
):
self.id = TaskWrapper._id_index
TaskWrapper._id_index += 1
@@ -141,7 +137,9 @@ class TaskWrapper:
exception_traceback = 'Traceback (most recent call last):\n'
for frame in self.task_stack:
exception_traceback += f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n'
exception_traceback += (
f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n'
)
exception_traceback += f' {self.assume_exception().__str__()}\n'
@@ -156,13 +154,9 @@ class TaskWrapper:
'runtime': {
'done': self.task.done(),
'state': self.task._state,
'exception': self.assume_exception().__str__()
if self.assume_exception() is not None
else None,
'exception': self.assume_exception().__str__() if self.assume_exception() is not None else None,
'exception_traceback': exception_traceback,
'result': self.assume_result().__str__()
if self.assume_result() is not None
else None,
'result': self.assume_result().__str__() if self.assume_result() is not None else None,
},
}
@@ -191,13 +185,9 @@ class AsyncTaskManager:
name: str = '',
label: str = '',
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [
core_entities.LifecycleControlScope.APPLICATION
],
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
wrapper = TaskWrapper(
self.ap, coro, task_type, kind, name, label, context, scopes
)
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
self.tasks.append(wrapper)
return wrapper
@@ -208,9 +198,7 @@ class AsyncTaskManager:
name: str = '',
label: str = '',
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [
core_entities.LifecycleControlScope.APPLICATION
],
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
return self.create_task(coro, 'user', kind, name, label, context, scopes)
@@ -225,9 +213,7 @@ class AsyncTaskManager:
type: str = None,
) -> dict:
return {
'tasks': [
t.to_dict() for t in self.tasks if type is None or t.task_type == type
],
'tasks': [t.to_dict() for t in self.tasks if type is None or t.task_type == type],
'id_index': TaskWrapper._id_index,
}

View File

@@ -114,9 +114,7 @@ class Component(pydantic.BaseModel):
_execution: Execution
"""组件执行"""
def __init__(
self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str
):
def __init__(self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str):
super().__init__(
owner=owner,
manifest=manifest,
@@ -125,19 +123,12 @@ class Component(pydantic.BaseModel):
)
self._metadata = Metadata(**manifest['metadata'])
self._spec = manifest['spec']
self._execution = (
Execution(**manifest['execution']) if 'execution' in manifest else None
)
self._execution = Execution(**manifest['execution']) if 'execution' in manifest else None
@classmethod
def is_component_manifest(cls, manifest: typing.Dict[str, typing.Any]) -> bool:
"""判断是否为组件清单"""
return (
'apiVersion' in manifest
and 'kind' in manifest
and 'metadata' in manifest
and 'spec' in manifest
)
return 'apiVersion' in manifest and 'kind' in manifest and 'metadata' in manifest and 'spec' in manifest
@property
def kind(self) -> str:
@@ -200,9 +191,7 @@ class ComponentDiscoveryEngine:
def __init__(self, ap: app.Application):
self.ap = ap
def load_component_manifest(
self, path: str, owner: str = 'builtin', no_save: bool = False
) -> Component | None:
def load_component_manifest(self, path: str, owner: str = 'builtin', no_save: bool = False) -> Component | None:
"""加载组件清单"""
with open(path, 'r', encoding='utf-8') as f:
manifest = yaml.safe_load(f)
@@ -229,18 +218,12 @@ class ComponentDiscoveryEngine:
if depth > max_depth:
return
for file in os.listdir(path):
if (not os.path.isdir(os.path.join(path, file))) and (
file.endswith('.yaml') or file.endswith('.yml')
):
comp = self.load_component_manifest(
os.path.join(path, file), owner, no_save
)
if (not os.path.isdir(os.path.join(path, file))) and (file.endswith('.yaml') or file.endswith('.yml')):
comp = self.load_component_manifest(os.path.join(path, file), owner, no_save)
if comp is not None:
components.append(comp)
elif os.path.isdir(os.path.join(path, file)):
recursive_load_component_manifests_in_dir(
os.path.join(path, file), depth + 1
)
recursive_load_component_manifests_in_dir(os.path.join(path, file), depth + 1)
recursive_load_component_manifests_in_dir(path)
return components
@@ -259,18 +242,12 @@ class ComponentDiscoveryEngine:
for dir in group['fromDirs']:
path = dir['path']
max_depth = dir['maxDepth'] if 'maxDepth' in dir else 1
components.extend(
self.load_component_manifests_in_dir(
path, owner, no_save, max_depth
)
)
components.extend(self.load_component_manifests_in_dir(path, owner, no_save, max_depth))
return components
def discover_blueprint(self, blueprint_manifest_path: str, owner: str = 'builtin'):
"""发现蓝图"""
blueprint_manifest = self.load_component_manifest(
blueprint_manifest_path, owner, no_save=True
)
blueprint_manifest = self.load_component_manifest(blueprint_manifest_path, owner, no_save=True)
if blueprint_manifest is None:
raise ValueError(f'Invalid blueprint manifest: {blueprint_manifest_path}')
assert blueprint_manifest.kind == 'Blueprint', '`Kind` must be `Blueprint`'
@@ -297,9 +274,7 @@ class ComponentDiscoveryEngine:
return []
return self.components[kind]
def find_components(
self, kind: str, component_list: typing.List[Component]
) -> typing.List[Component]:
def find_components(self, kind: str, component_list: typing.List[Component]) -> typing.List[Component]:
"""查找组件"""
result: typing.List[Component] = []
for component in component_list:

View File

@@ -16,9 +16,7 @@ class Bot(Base):
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
use_pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
use_pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -16,9 +16,7 @@ class LLMModel(Base):
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -11,9 +11,7 @@ class LegacyPipeline(Base):
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,
@@ -35,9 +33,7 @@ class PipelineRunRecord(Base):
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
status = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -13,9 +13,7 @@ class PluginSetting(Base):
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -9,9 +9,7 @@ class User(Base):
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -11,6 +11,4 @@ class SQLiteDatabaseManager(database.BaseDatabaseManager):
async def initialize(self) -> None:
sqlite_path = 'data/langbot.db'
self.engine = sqlalchemy_asyncio.create_async_engine(
f'sqlite+aiosqlite:///{sqlite_path}'
)
self.engine = sqlalchemy_asyncio.create_async_engine(f'sqlite+aiosqlite:///{sqlite_path}')

View File

@@ -58,24 +58,18 @@ class PersistenceManager:
for item in metadata.initial_metadata:
# check if the item exists
result = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(
metadata.Metadata.key == item['key']
)
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
)
row = result.first()
if row is None:
await self.execute_async(
sqlalchemy.insert(metadata.Metadata).values(item)
)
await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item))
# write default pipeline
result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline))
if result.first() is None:
self.ap.logger.info('Creating default pipeline...')
pipeline_config = json.load(
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
)
pipeline_config = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
pipeline_data = {
'uuid': str(uuid.uuid4()),
@@ -87,16 +81,12 @@ class PersistenceManager:
'config': pipeline_config,
}
await self.execute_async(
sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data)
)
await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data))
# =================================
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(
metadata.Metadata.key == 'database_version'
)
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
@@ -122,17 +112,11 @@ class PersistenceManager:
.values({'value': str(migration_instance.number)})
)
last_migration_number = migration_instance.number
self.ap.logger.info(
f'Migration {migration_instance.number} completed.'
)
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(
f'Successfully upgraded database to version {last_migration_number}.'
)
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def execute_async(
self, *args, **kwargs
) -> sqlalchemy.engine.cursor.CursorResult:
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
async with self.get_db_engine().connect() as conn:
result = await conn.execute(*args, **kwargs)
await conn.commit()
@@ -141,9 +125,7 @@ class PersistenceManager:
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine()
def serialize_model(
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base
) -> dict:
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
return {
column.name: getattr(data, column.name)
if not isinstance(getattr(data, column.name), (datetime.datetime))

View File

@@ -14,9 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
async def initialize(self, pipeline_config: dict):
pass
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
found = False
mode = query.pipeline_config['trigger']['access-control']['mode']
@@ -41,11 +39,7 @@ class BanSessionCheckStage(stage.PipelineStage):
ctn = not found
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE
if ctn
else entities.ResultType.INTERRUPT,
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}'
if not ctn
else '',
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '',
)

View File

@@ -65,9 +65,7 @@ class ContentFilterStage(stage.PipelineStage):
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
@@ -86,13 +84,9 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = platform_message.MessageChain(
platform_message.Plain(message)
)
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def _post_process(
self,
@@ -103,9 +97,7 @@ class ContentFilterStage(stage.PipelineStage):
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
message = message.strip()
for filter in self.filter_chain:
@@ -127,13 +119,9 @@ class ContentFilterStage(stage.PipelineStage):
query.resp_messages[-1].content = message
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
@@ -147,9 +135,7 @@ class ContentFilterStage(stage.PipelineStage):
if contain_non_text:
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage':
@@ -162,8 +148,6 @@ class ContentFilterStage(stage.PipelineStage):
self.ap.logger.debug(
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -60,9 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(
self, query: core_entities.Query, message: str = None, image_url=None
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。

View File

@@ -21,19 +21,13 @@ class BaiduCloudExamine(filter_model.ContentFilter):
BAIDU_EXAMINE_TOKEN_URL,
params={
'grant_type': 'client_credentials',
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-key'
],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-secret'
],
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
},
) as resp:
return (await resp.json())['access_token']
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),

View File

@@ -13,9 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self):
pass
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:
@@ -31,9 +29,7 @@ class BanWordFilter(filter_model.ContentFilter):
self.ap.sensitive_meta.data['mask'] * len(match[i]),
)
else:
message = message.replace(
match[i], self.ap.sensitive_meta.data['mask_word']
)
message = message.replace(match[i], self.ap.sensitive_meta.data['mask_word'])
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,

View File

@@ -16,9 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE,
]
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):

View File

@@ -16,9 +16,7 @@ class Controller:
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(
self.ap.instance_config.data['concurrency']['pipeline']
)
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
async def consumer(self):
"""事件处理循环"""
@@ -32,9 +30,7 @@ class Controller:
for query in queries:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(
f'Checking query {query} session {session}'
)
self.ap.logger.debug(f'Checking query {query} session {session}')
if not session.semaphore.locked():
selected_query = query
@@ -55,22 +51,16 @@ class Controller:
# find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
bot = await self.ap.platform_mgr.get_bot_by_uuid(
selected_query.bot_uuid
)
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
if bot:
pipeline = (
await self.ap.pipeline_mgr.get_pipeline_by_uuid(
bot.bot_entity.use_pipeline_uuid
)
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(
bot.bot_entity.use_pipeline_uuid
)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(
await self.ap.sess_mgr.get_session(selected_query)
).semaphore.release()
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()

View File

@@ -47,9 +47,7 @@ class LongTextProcessStage(stage.PipelineStage):
'未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'
)
pipeline_config['output']['long-text-processing'][
'strategy'
] = 'forward'
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
except Exception:
traceback.print_exc()
self.ap.logger.error(
@@ -58,9 +56,7 @@ class LongTextProcessStage(stage.PipelineStage):
)
)
pipeline_config['output']['long-text-processing']['strategy'] = (
'forward'
)
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
@@ -71,9 +67,7 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize()
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
# 检查是否包含非 Plain 组件
contains_non_plain = False
@@ -89,11 +83,7 @@ class LongTextProcessStage(stage.PipelineStage):
> query.pipeline_config['output']['long-text-processing']['threshold']
):
query.resp_message_chain[-1] = platform_message.MessageChain(
await self.strategy_impl.process(
str(query.resp_message_chain[-1]), query
)
await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -13,9 +13,7 @@ Forward = platform_message.Forward
@strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay(
title='群聊的聊天记录',
brief='[聊天记录]',

View File

@@ -27,18 +27,14 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
encoding='utf-8',
)
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time())),
query=query,
)
compressed_path, size = self.compress_image(
img_path, outfile='temp/{}_compressed.png'.format(int(time.time()))
)
compressed_path, size = self.compress_image(img_path, outfile='temp/{}_compressed.png'.format(int(time.time())))
with open(compressed_path, 'rb') as f:
img = f.read()
@@ -165,10 +161,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
numbers = self.indexNumber(rest_text)
for number in numbers:
if (
number[1] < point < number[1] + len(number[0])
and number[1] != 0
):
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
@@ -181,9 +174,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
else:
continue
# 准备画布
img = Image.new(
'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)
)
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug('正在绘制图片...')

View File

@@ -49,9 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
"""处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法

View File

@@ -29,12 +29,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
else:
raise ValueError(f'未知的截断器: {use_method}')
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -79,26 +79,20 @@ class RuntimePipeline:
query.pipeline_config = self.pipeline_entity.config
await self.process_query(query)
async def _check_output(
self, query: entities.Query, result: pipeline_entities.StageProcessResult
):
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出"""
if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(
platform_message.Plain(result.user_notice)
)
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(*result.user_notice)
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
result.user_notice.insert(
0, platform_message.At(query.message_event.sender.id)
)
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
await query.adapter.reply_message(
message_source=query.message_event,
@@ -150,37 +144,25 @@ class RuntimePipeline:
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} res {result}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}')
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} interrupted query {query}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} gen'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen')
async for sub_result in result:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} res {sub_result}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}')
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} interrupted query {query}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
break
elif (
sub_result.result_type == pipeline_entities.ResultType.CONTINUE
):
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
@@ -214,12 +196,8 @@ class RuntimePipeline:
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = (
query.current_stage.inst_name if query.current_stage else 'unknown'
)
self.ap.logger.error(
f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}'
)
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
finally:
self.ap.logger.debug(f'Query {query} processed')
@@ -241,18 +219,14 @@ class PipelineManager:
self.pipelines = []
async def initialize(self):
self.stage_dict = {
name: cls for name, cls in stage.preregistered_stages.items()
}
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
await self.load_pipelines_from_db()
async def load_pipelines_from_db(self):
self.ap.logger.info('Loading pipelines from db...')
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
pipelines = result.all()
@@ -267,20 +241,14 @@ class PipelineManager:
| dict,
):
if isinstance(pipeline_entity, sqlalchemy.Row):
pipeline_entity = persistence_pipeline.LegacyPipeline(
**pipeline_entity._mapping
)
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
elif isinstance(pipeline_entity, dict):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers: list[StageInstContainer] = []
for stage_name in pipeline_entity.stages:
stage_containers.append(
StageInstContainer(
inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)
)
)
stage_containers.append(StageInstContainer(inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)))
for stage_container in stage_containers:
await stage_container.inst.initialize(pipeline_entity.config)

View File

@@ -44,9 +44,7 @@ class PreProcessor(stage.PipelineStage):
query.use_llm_model = conversation.use_llm_model
query.use_funcs = (
conversation.use_funcs
if query.use_llm_model.model_entity.abilities.__contains__('tool_call')
else None
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
)
query.variables = {
@@ -59,10 +57,9 @@ class PreProcessor(stage.PipelineStage):
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if (
query.pipeline_config['ai']['runner']['runner'] == 'local-agent'
and not query.use_llm_model.model_entity.abilities.__contains__('vision')
):
if query.pipeline_config['ai']['runner'][
'runner'
] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -78,14 +75,11 @@ class PreProcessor(stage.PipelineStage):
content_list.append(llm_entities.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if (
query.pipeline_config['ai']['runner']['runner'] != 'local-agent'
or query.use_llm_model.model_entity.abilities.__contains__('vision')
):
if query.pipeline_config['ai']['runner'][
'runner'
] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_base64(me.base64)
)
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
query.variables['user_message_text'] = plain_text
@@ -104,6 +98,4 @@ class PreProcessor(stage.PipelineStage):
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -49,13 +49,9 @@ class ChatMessageHandler(handler.MessageHandler):
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
@@ -69,34 +65,24 @@ class ChatMessageHandler(handler.MessageHandler):
runner = r(self.ap, query.pipeline_config)
break
else:
raise ValueError(
f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}'
)
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(
f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}'
)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e:
self.ap.logger.error(
f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}'
)
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
hide_exception_info = query.pipeline_config['output']['misc'][
'hide-exception'
]
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,

View File

@@ -21,10 +21,7 @@ class CommandHandler(handler.MessageHandler):
privilege = 1
if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2
spt = command_text.split(' ')
@@ -54,25 +51,17 @@ class CommandHandler(handler.MessageHandler):
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.alter is not None:
query.message_chain = platform_message.MessageChain(
[platform_message.Plain(event_ctx.event.alter)]
)
query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)])
session = await self.ap.sess_mgr.get_session(query)
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text, query=query, session=session
):
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
if ret.error is not None:
query.resp_messages.append(
llm_entities.Message(
@@ -81,13 +70,9 @@ class CommandHandler(handler.MessageHandler):
)
)
self.ap.logger.info(
f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}'
)
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif ret.text is not None or ret.image_url is not None:
content: list[llm_entities.ContentElement] = []
@@ -95,9 +80,7 @@ class CommandHandler(handler.MessageHandler):
content.append(llm_entities.ContentElement.from_text(ret.text))
if ret.image_url is not None:
content.append(
llm_entities.ContentElement.from_image_url(ret.image_url)
)
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
query.resp_messages.append(
llm_entities.Message(
@@ -108,10 +91,6 @@ class CommandHandler(handler.MessageHandler):
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)

View File

@@ -72,9 +72,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
if count >= limitation:
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
return False
elif (
query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait'
):
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
# 等待下一窗口
await asyncio.sleep(window_size - time.time() % window_size)

View File

@@ -15,9 +15,7 @@ from ...core import entities as core_entities
class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息"""
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
random_range = (
@@ -34,9 +32,7 @@ class SendResponseBackStage(stage.PipelineStage):
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
query.resp_message_chain[-1].insert(
0, platform_message.At(query.message_event.sender.id)
)
query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id))
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
@@ -46,6 +42,4 @@ class SendResponseBackStage(stage.PipelineStage):
quote_origin=quote_origin,
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -32,13 +32,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
rules = query.pipeline_config['trigger']['group-respond-rules']
@@ -49,9 +45,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
# use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(
str(query.message_chain), query.message_chain, use_rule, query
)
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
if res.matching:
query.message_chain = res.replacement
@@ -60,6 +54,4 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
new_query=query,
)
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)

View File

@@ -16,10 +16,7 @@ class AtBotRule(rule_model.GroupRespondRule):
rule_dict: dict,
query: core_entities.Query,
) -> entities.RuleJudgeResult:
if (
message_chain.has(platform_message.At(query.adapter.bot_account_id))
and rule_dict['at']
):
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
if message_chain.has(

View File

@@ -18,6 +18,4 @@ class RandomRespRule(rule_model.GroupRespondRule):
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']
return entities.RuleJudgeResult(
matching=random.random() < random_rate, replacement=message_chain
)
return entities.RuleJudgeResult(matching=random.random() < random_rate, replacement=message_chain)

View File

@@ -34,29 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
if query.resp_messages[-1].role == 'command':
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain(
prefix_text='[bot] '
)
query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ')
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif query.resp_messages[-1].role == 'plugin':
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain()
)
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
if query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
@@ -77,9 +67,7 @@ class ResponseWrapper(stage.PipelineStage):
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[
fc.function.name for fc in result.tool_calls
]
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
@@ -92,36 +80,26 @@ class ResponseWrapper(stage.PipelineStage):
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(
platform_message.MessageChain(event_ctx.event.reply)
)
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain.append(
result.get_content_platform_message_chain()
)
query.resp_message_chain.append(result.get_content_platform_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
if (
result.tool_calls is not None and len(result.tool_calls) > 0
): # 有函数调用
if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(
platform_message.MessageChain(
[platform_message.Plain(reply_text)]
)
platform_message.MessageChain([platform_message.Plain(reply_text)])
)
if query.pipeline_config['output']['misc'][
'track-function-calls'
]:
if query.pipeline_config['output']['misc']['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
@@ -131,9 +109,7 @@ class ResponseWrapper(stage.PipelineStage):
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[
fc.function.name for fc in result.tool_calls
]
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
@@ -148,16 +124,12 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(
platform_message.MessageChain(
event_ctx.event.reply
)
platform_message.MessageChain(event_ctx.event.reply)
)
else:
query.resp_message_chain.append(
platform_message.MessageChain(
[platform_message.Plain(reply_text)]
)
platform_message.MessageChain([platform_message.Plain(reply_text)])
)
yield entities.StageProcessResult(

View File

@@ -32,9 +32,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
self.config = config
self.ap = ap
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息
Args:
@@ -66,9 +64,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
def register_listener(
self,
event_type: typing.Type[platform_message.Event],
callback: typing.Callable[
[platform_message.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
):
"""注册事件监听器
@@ -81,9 +77,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
def unregister_listener(
self,
event_type: typing.Type[platform_message.Event],
callback: typing.Callable[
[platform_message.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
):
"""注销事件监听器

View File

@@ -132,14 +132,10 @@ class PlatformManager:
self.adapter_dict = {}
async def initialize(self):
self.adapter_components = self.ap.discover.get_components_by_kind(
'MessagePlatformAdapter'
)
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {}
for component in self.adapter_components:
adapter_dict[component.metadata.name] = (
component.get_python_component_class()
)
adapter_dict[component.metadata.name] = component.get_python_component_class()
self.adapter_dict = adapter_dict
await self.load_bots_from_db()
@@ -152,9 +148,7 @@ class PlatformManager:
self.bots = []
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
bots = result.all()
@@ -172,13 +166,9 @@ class PlatformManager:
elif isinstance(bot_entity, dict):
bot_entity = persistence_bot.Bot(**bot_entity)
adapter_inst = self.adapter_dict[bot_entity.adapter](
bot_entity.adapter_config, self.ap
)
adapter_inst = self.adapter_dict[bot_entity.adapter](bot_entity.adapter_config, self.ap)
runtime_bot = RuntimeBot(
ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst
)
runtime_bot = RuntimeBot(ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst)
await runtime_bot.initialize()
@@ -209,9 +199,7 @@ class PlatformManager:
return component.to_plain_dict()
return None
def get_available_adapter_manifest_by_name(
self, name: str
) -> engine.Component | None:
def get_available_adapter_manifest_by_name(self, name: str) -> engine.Component | None:
for component in self.adapter_components:
if component.metadata.name == name:
return component

View File

@@ -58,13 +58,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
elif type(msg) is platform_message.Forward:
for node in msg.node_list:
msg_list.extend(
(
await AiocqhttpMessageConverter.yiri2target(
node.message_chain
)
)[0]
)
msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0])
else:
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
@@ -77,9 +71,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
for msg in message:
if msg.type == 'at':
@@ -94,14 +86,8 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.type == 'text':
yiri_msg_list.append(platform_message.Plain(text=msg.data['text']))
elif msg.type == 'image':
image_base64, image_format = await image.qq_image_url_to_base64(
msg.data['url']
)
yiri_msg_list.append(
platform_message.Image(
base64=f'data:image/{image_format};base64,{image_base64}'
)
)
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
yiri_msg_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
chain = platform_message.MessageChain(yiri_msg_list)
@@ -115,9 +101,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod
async def target2yiri(event: aiocqhttp.Event):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(
event.message, event.message_id
)
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id)
if event.message_type == 'group':
permission = 'MEMBER'
@@ -137,9 +121,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
name=event.sender['nickname'],
permission=platform_entities.Permission.Member,
),
special_title=event.sender['title']
if 'title' in event.sender
else '',
special_title=event.sender['title'] if 'title' in event.sender else '',
join_timestamp=0,
last_speak_timestamp=0,
mute_time_remaining=0,
@@ -191,9 +173,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
else:
self.bot = aiocqhttp.CQHttp()
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if target_type == 'group':
@@ -207,14 +187,10 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
message: platform_message.MessageChain,
quote_origin: bool = False,
):
aiocq_event = await AiocqhttpEventConverter.yiri2target(
message_source, self.bot_account_id
)
aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if quote_origin:
aiocq_msg = (
aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
)
aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
return await self.bot.send(aiocq_event, aiocq_msg)
@@ -224,16 +200,12 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
try:
return await callback(
await self.event_converter.target2yiri(event), self
)
return await callback(await self.event_converter.target2yiri(event), self)
except Exception:
traceback.print_exc()
@@ -245,9 +217,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -22,9 +22,7 @@ class DingTalkMessageConverter(adapter.MessageConverter):
async def target2yiri(event: DingTalkEvent, bot_name: str):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(
id=event.incoming_message.message_id, time=datetime.datetime.now()
)
platform_message.Source(id=event.incoming_message.message_id, time=datetime.datetime.now())
)
for atUser in event.incoming_message.at_users:
@@ -133,9 +131,7 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
content = await DingTalkMessageConverter.yiri2target(message)
await self.bot.send_message(content, incoming_message)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
content = await DingTalkMessageConverter.yiri2target(message)
if target_type == 'person':
await self.bot.send_proactive_message_to_one(target_id, content)
@@ -145,16 +141,12 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: DingTalkEvent):
try:
return await callback(
await self.event_converter.target2yiri(
event, self.config['robot_name']
),
await self.event_converter.target2yiri(event, self.config['robot_name']),
self,
)
except Exception:
@@ -174,8 +166,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -45,9 +45,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
with open(ele.path, 'rb') as f:
image_bytes = f.read()
image_files.append(
discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png')
)
image_files.append(discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png'))
elif isinstance(ele, platform_message.Plain):
text_string += ele.text
elif isinstance(ele, platform_message.Forward):
@@ -65,9 +63,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
async def target2yiri(message: discord.Message) -> platform_message.MessageChain:
lb_msg_list = []
msg_create_time = datetime.datetime.fromtimestamp(
int(message.created_at.timestamp())
)
msg_create_time = datetime.datetime.fromtimestamp(int(message.created_at.timestamp()))
lb_msg_list.append(platform_message.Source(id=message.id, time=msg_create_time))
@@ -97,11 +93,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
else:
mid_at_component.append(platform_message.At(target=mid_at[2:-1]))
return (
text_element_recur(text_split[0])
+ mid_at_component
+ text_element_recur(text_split[1])
)
return text_element_recur(text_split[0]) + mid_at_component + text_element_recur(text_split[1])
else:
return [platform_message.Plain(text=text_ele)]
@@ -114,11 +106,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
image_data = await response.read()
image_base64 = base64.b64encode(image_data).decode('utf-8')
image_format = response.headers['Content-Type']
element_list.append(
platform_message.Image(
base64=f'data:{image_format};base64,{image_base64}'
)
)
element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
return platform_message.MessageChain(element_list)
@@ -208,9 +196,7 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
self.bot = MyClient(intents=intents, **args)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
async def reply_message(
@@ -243,18 +229,14 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners.pop(event_type)

View File

@@ -40,14 +40,10 @@ class GewechatMessageConverter(adapter.MessageConverter):
content_list.append({'type': 'image', 'image': component.url})
elif isinstance(component, platform_message.Voice):
content_list.append(
{'type': 'voice', 'url': component.url, 'length': component.length}
)
content_list.append({'type': 'voice', 'url': component.url, 'length': component.length})
elif isinstance(component, platform_message.Forward):
for node in component.node_list:
content_list.extend(
await GewechatMessageConverter.yiri2target(node.message_chain)
)
content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain))
content_list.append({'type': 'image', 'image': component.url})
elif isinstance(component, platform_message.WeChatMiniPrograms):
content_list.append(
@@ -88,44 +84,26 @@ class GewechatMessageConverter(adapter.MessageConverter):
}
)
elif isinstance(component, platform_message.WeChatForwardLink):
content_list.append(
{'type': 'WeChatForwardLink', 'xml_data': component.xml_data}
)
content_list.append({'type': 'WeChatForwardLink', 'xml_data': component.xml_data})
elif isinstance(component, platform_message.Voice):
content_list.append(
{'type': 'voice', 'url': component.url, 'length': component.length}
)
content_list.append({'type': 'voice', 'url': component.url, 'length': component.length})
elif isinstance(component, platform_message.WeChatForwardImage):
content_list.append(
{'type': 'WeChatForwardImage', 'xml_data': component.xml_data}
)
content_list.append({'type': 'WeChatForwardImage', 'xml_data': component.xml_data})
elif isinstance(component, platform_message.WeChatForwardFile):
content_list.append(
{'type': 'WeChatForwardFile', 'xml_data': component.xml_data}
)
content_list.append({'type': 'WeChatForwardFile', 'xml_data': component.xml_data})
elif isinstance(component, platform_message.WeChatAppMsg):
content_list.append(
{'type': 'WeChatAppMsg', 'app_msg': component.app_msg}
)
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
# 引用消息转发
elif isinstance(component, platform_message.WeChatForwardQuote):
content_list.append(
{'type': 'WeChatAppMsg', 'app_msg': component.app_msg}
)
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
elif isinstance(component, platform_message.Forward):
for node in component.node_list:
if node.message_chain:
content_list.extend(
await GewechatMessageConverter.yiri2target(
node.message_chain
)
)
content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain))
return content_list
async def target2yiri(
self, message: dict, bot_account_id: str
) -> platform_message.MessageChain:
async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain:
"""外部消息转平台消息"""
# 数据预处理
message_list = []
@@ -163,28 +141,20 @@ class GewechatMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_list)
async def _handler_text(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理文本消息 (msg_type=1)"""
if message and self._is_group_message(message):
pattern = r'@\S{1,20}'
content_no_preifx = re.sub(pattern, '', content_no_preifx)
return platform_message.MessageChain(
[platform_message.Plain(content_no_preifx)]
)
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
async def _handler_image(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理图像消息 (msg_type=3)"""
try:
image_xml = content_no_preifx
if not image_xml:
return platform_message.MessageChain(
[platform_message.Unknown('[图片内容为空]')]
)
return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')])
base64_str, image_format = await image.get_gewechat_image_base64(
gewechat_url=self.config['gewechat_url'],
@@ -196,21 +166,15 @@ class GewechatMessageConverter(adapter.MessageConverter):
)
elements = [
platform_message.Image(
base64=f'data:image/{image_format};base64,{base64_str}'
),
platform_message.Image(base64=f'data:image/{image_format};base64,{base64_str}'),
platform_message.WeChatForwardImage(xml_data=image_xml), # 微信消息转发
]
return platform_message.MessageChain(elements)
except Exception as e:
print(f'处理图片失败: {str(e)}')
return platform_message.MessageChain(
[platform_message.Unknown('[图片处理失败]')]
)
return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')])
async def _handler_voice(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理语音消息 (msg_type=34)"""
message_List = []
try:
@@ -223,9 +187,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_List)
# 转换为平台支持的语音格式(如 Silk 格式)
voice_element = platform_message.Voice(
base64=f'data:audio/silk;base64,{audio_base64}'
)
voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}')
message_List.append(voice_element)
except KeyError as e:
@@ -237,9 +199,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_List)
async def _handler_compound(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理复合消息 (msg_type=49),根据子类型分派"""
try:
xml_data = ET.fromstring(content_no_preifx)
@@ -254,33 +214,21 @@ class GewechatMessageConverter(adapter.MessageConverter):
'6': self._handler_compound_file,
'33': self._handler_compound_mini_program,
'36': self._handler_compound_mini_program,
'2000': partial(
self._handler_compound_unsupported, text='[转账消息]'
),
'2001': partial(
self._handler_compound_unsupported, text='[红包消息]'
),
'51': partial(
self._handler_compound_unsupported, text='[视频号消息]'
),
'2000': partial(self._handler_compound_unsupported, text='[转账消息]'),
'2001': partial(self._handler_compound_unsupported, text='[红包消息]'),
'51': partial(self._handler_compound_unsupported, text='[视频号消息]'),
}
handler = sub_handler_map.get(
data_type, self._handler_compound_unsupported
)
handler = sub_handler_map.get(data_type, self._handler_compound_unsupported)
return await handler(
message=message, # 原始msg
xml_data=xml_data, # xml数据
)
else:
return platform_message.MessageChain(
[platform_message.Unknown(text=content_no_preifx)]
)
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
except Exception as e:
print(f'解析复合消息失败: {str(e)}')
return platform_message.MessageChain(
[platform_message.Unknown(text=content_no_preifx)]
)
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
async def _handler_compound_quote(
self, message: Optional[dict], xml_data: ET.Element
@@ -296,9 +244,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
user_data = appmsg_data.findtext('.//title') or ''
quote_data = appmsg_data.find('.//refermsg').findtext('.//content')
message_list.append(
platform_message.WeChatForwardQuote(
app_msg=ET.tostring(appmsg_data, encoding='unicode')
)
platform_message.WeChatForwardQuote(app_msg=ET.tostring(appmsg_data, encoding='unicode'))
)
# quote_data原始的消息
if quote_data:
@@ -311,22 +257,14 @@ class GewechatMessageConverter(adapter.MessageConverter):
# 引用消息展开
quote_data_xml = ET.fromstring(quote_data)
if quote_data_xml.find('img'):
quote_data_message_list.extend(
await self._handler_image(None, quote_data)
)
quote_data_message_list.extend(await self._handler_image(None, quote_data))
elif quote_data_xml.find('voicemsg'):
quote_data_message_list.extend(
await self._handler_voice(None, quote_data)
)
quote_data_message_list.extend(await self._handler_voice(None, quote_data))
elif quote_data_xml.find('videomsg'):
quote_data_message_list.extend(
await self._handler_default(None, quote_data)
) # 先不处理
quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理
else:
# appmsg
quote_data_message_list.extend(
await self._handler_compound(None, quote_data)
)
quote_data_message_list.extend(await self._handler_compound(None, quote_data))
except Exception as e:
print(f'处理引用消息异常 expcetion:{e}')
quote_data_message_list.append(platform_message.Plain(quote_data))
@@ -351,18 +289,12 @@ class GewechatMessageConverter(adapter.MessageConverter):
# print(f"quote_message_chain plain [msg_type={comp.type}][message={comp.text}]")
return platform_message.MessageChain(message_list)
async def _handler_compound_file(
self, message: dict, xml_data: ET.Element
) -> platform_message.MessageChain:
async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
"""处理文件消息 (data_type=6)"""
xml_data_str = ET.tostring(xml_data, encoding='unicode')
return platform_message.MessageChain(
[platform_message.WeChatForwardFile(xml_data=xml_data_str)]
)
return platform_message.MessageChain([platform_message.WeChatForwardFile(xml_data=xml_data_str)])
async def _handler_compound_link(
self, message: dict, xml_data: ET.Element
) -> platform_message.MessageChain:
async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
"""处理链接消息(如公众号文章、外部网页)"""
message_list = []
try:
@@ -381,9 +313,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
# 转发消息
xml_data_str = ET.tostring(xml_data, encoding='unicode')
# print(xml_data_str)
message_list.append(
platform_message.WeChatForwardLink(xml_data=xml_data_str)
)
message_list.append(platform_message.WeChatForwardLink(xml_data=xml_data_str))
except Exception as e:
print(f'解析链接消息失败: {str(e)}')
return platform_message.MessageChain(message_list)
@@ -393,21 +323,15 @@ class GewechatMessageConverter(adapter.MessageConverter):
) -> platform_message.MessageChain:
"""处理小程序消息(如小程序卡片、服务通知)"""
xml_data_str = ET.tostring(xml_data, encoding='unicode')
return platform_message.MessageChain(
[platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)]
)
return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)])
async def _handler_default(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理未知消息类型"""
if message:
msg_type = message['Data']['MsgType']
else:
msg_type = ''
return platform_message.MessageChain(
[platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')]
)
return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')])
def _handler_compound_unsupported(
self, message: dict, xml_data: str, text: Optional[str] = None
@@ -416,11 +340,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
if not text:
text = f'[xml_data={xml_data}]'
content_list = []
content_list.append(
platform_message.Unknown(
text=f'[处理未支持复合消息类型[msg_type=49]|{text}'
)
)
content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}'))
return platform_message.MessageChain(content_list)
@@ -448,9 +368,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
appmsg_data = xml_data.find('.//appmsg')
tousername = message['Wxid']
if appmsg_data: # 接收方: 所属微信的wxid
quote_id = appmsg_data.find('.//refermsg').findtext(
'.//chatusr'
) # 引用消息的原发送者
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者
ats_bot = ats_bot or (quote_id == tousername)
except Exception as e:
print(f'_ats_bot got except: {e}')
@@ -458,9 +376,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
return ats_bot
# 提取一下content前面的sender_id, 和去掉前缀的内容
def _extract_content_and_sender(
self, raw_content: str
) -> Tuple[str, Optional[str]]:
def _extract_content_and_sender(self, raw_content: str) -> Tuple[str, Optional[str]]:
try:
# 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉
# add: 有些用户的wxid不是上述格式。换成user_name:
@@ -490,21 +406,17 @@ class GewechatEventConverter(adapter.EventConverter):
async def yiri2target(event: platform_events.MessageEvent) -> dict:
pass
async def target2yiri(
self, event: dict, bot_account_id: str
) -> platform_events.MessageEvent:
async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent:
# print(event)
# 排除自己发消息回调回答问题
if event['Wxid'] == event['Data']['FromUserName']['string']:
return None
# 排除公众号以及微信团队消息
if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data'][
'FromUserName'
]['string'].startswith('weixin'):
if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data']['FromUserName'][
'string'
].startswith('weixin'):
return None
message_chain = await self.message_converter.target2yiri(
copy.deepcopy(event), bot_account_id
)
message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id)
if not message_chain:
return None
@@ -589,9 +501,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
return 'ok'
elif 'TypeName' in data and data['TypeName'] == 'AddMsg':
try:
event = await self.event_converter.target2yiri(
data.copy(), self.bot_account_id
)
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
except Exception:
traceback.print_exc()
@@ -600,9 +510,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
return 'ok'
async def _handle_message(
self, message: platform_message.MessageChain, target_id: str
):
async def _handle_message(self, message: platform_message.MessageChain, target_id: str):
"""统一消息处理核心逻辑"""
content_list = await self.message_converter.yiri2target(message)
at_targets = [item['target'] for item in content_list if item['type'] == 'at']
@@ -611,9 +519,9 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
at_targets = at_targets or []
member_info = []
if at_targets:
member_info = self.bot.get_chatroom_member_detail(
self.config['app_id'], target_id, at_targets[::-1]
)['data']
member_info = self.bot.get_chatroom_member_detail(self.config['app_id'], target_id, at_targets[::-1])[
'data'
]
# 处理消息组件
for msg in content_list:
@@ -694,9 +602,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息"""
return await self._handle_message(message, target_id)
@@ -708,9 +614,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
):
"""回复消息"""
if message_source.source_platform_object:
target_id = message_source.source_platform_object['Data']['FromUserName'][
'string'
]
target_id = message_source.source_platform_object['Data']['FromUserName']['string']
return await self._handle_message(message, target_id)
async def is_muted(self, group_id: int) -> bool:
@@ -719,18 +623,14 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
pass
@@ -742,14 +642,10 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
json={'app_id': self.config['app_id']},
) as response:
if response.status != 200:
raise Exception(
f'获取gewechat token失败: {await response.text()}'
)
raise Exception(f'获取gewechat token失败: {await response.text()}')
self.config['token'] = (await response.json())['data']
self.bot = gewechat_client.GewechatClient(
f'{self.config["gewechat_url"]}/v2/api', self.config['token']
)
self.bot = gewechat_client.GewechatClient(f'{self.config["gewechat_url"]}/v2/api', self.config['token'])
def gewechat_login_process():
app_id, error_msg = self.bot.login(self.config['app_id'])

View File

@@ -71,14 +71,10 @@ class LarkMessageConverter(adapter.MessageConverter):
pending_paragraph.append({'tag': 'md', 'text': text})
except UnicodeError:
# If still fails, replace invalid characters
text = msg.text.encode('utf-8', errors='replace').decode(
'utf-8'
)
text = msg.text.encode('utf-8', errors='replace').decode('utf-8')
pending_paragraph.append({'tag': 'md', 'text': text})
elif isinstance(msg, platform_message.At):
pending_paragraph.append(
{'tag': 'at', 'user_id': msg.target, 'style': []}
)
pending_paragraph.append({'tag': 'at', 'user_id': msg.target, 'style': []})
elif isinstance(msg, platform_message.AtAll):
pending_paragraph.append({'tag': 'at', 'user_id': 'all', 'style': []})
elif isinstance(msg, platform_message.Image):
@@ -166,11 +162,7 @@ class LarkMessageConverter(adapter.MessageConverter):
os.unlink(temp_file.name)
elif isinstance(msg, platform_message.Forward):
for node in msg.node_list:
message_elements.extend(
await LarkMessageConverter.yiri2target(
node.message_chain, api_client
)
)
message_elements.extend(await LarkMessageConverter.yiri2target(node.message_chain, api_client))
if pending_paragraph:
message_elements.append(pending_paragraph)
@@ -186,13 +178,9 @@ class LarkMessageConverter(adapter.MessageConverter):
lb_msg_list = []
msg_create_time = datetime.datetime.fromtimestamp(
int(message.create_time) / 1000
)
msg_create_time = datetime.datetime.fromtimestamp(int(message.create_time) / 1000)
lb_msg_list.append(
platform_message.Source(id=message.message_id, time=msg_create_time)
)
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
if message.message_type == 'text':
element_list = []
@@ -222,9 +210,7 @@ class LarkMessageConverter(adapter.MessageConverter):
left_text = text_split[0]
right_text = text_split[1]
new_list.extend(
text_element_recur({'tag': 'text', 'text': left_text, 'style': []})
)
new_list.extend(text_element_recur({'tag': 'text', 'text': left_text, 'style': []}))
new_list.append(
{
@@ -235,15 +221,11 @@ class LarkMessageConverter(adapter.MessageConverter):
}
)
new_list.extend(
text_element_recur({'tag': 'text', 'text': right_text, 'style': []})
)
new_list.extend(text_element_recur({'tag': 'text', 'text': right_text, 'style': []}))
return new_list
element_list = text_element_recur(
{'tag': 'text', 'text': message_content['text'], 'style': []}
)
element_list = text_element_recur({'tag': 'text', 'text': message_content['text'], 'style': []})
message_content = {'title': '', 'content': element_list}
@@ -258,9 +240,7 @@ class LarkMessageConverter(adapter.MessageConverter):
message_content['content'] = new_list
elif message.message_type == 'image':
message_content['content'] = [
{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}
]
message_content['content'] = [{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}]
for ele in message_content['content']:
if ele['tag'] == 'text':
@@ -278,9 +258,7 @@ class LarkMessageConverter(adapter.MessageConverter):
.build()
)
response: GetMessageResourceResponse = (
await api_client.im.v1.message_resource.aget(request)
)
response: GetMessageResourceResponse = await api_client.im.v1.message_resource.aget(request)
if not response.success():
raise Exception(
@@ -292,11 +270,7 @@ class LarkMessageConverter(adapter.MessageConverter):
image_format = response.raw.headers['content-type']
lb_msg_list.append(
platform_message.Image(
base64=f'data:{image_format};base64,{image_base64}'
)
)
lb_msg_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
return platform_message.MessageChain(lb_msg_list)
@@ -312,9 +286,7 @@ class LarkEventConverter(adapter.EventConverter):
async def target2yiri(
event: lark_oapi.im.v1.P2ImMessageReceiveV1, api_client: lark_oapi.Client
) -> platform_events.Event:
message_chain = await LarkMessageConverter.target2yiri(
event.event.message, api_client
)
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
if event.event.message.chat_type == 'p2p':
return platform_events.FriendMessage(
@@ -402,9 +374,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
p2v1.schema = context.schema
if 'im.message.receive_v1' == type:
try:
event = await self.event_converter.target2yiri(
p2v1, self.api_client
)
event = await self.event_converter.target2yiri(p2v1, self.api_client)
except Exception:
traceback.print_exc()
@@ -425,26 +395,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
asyncio.create_task(on_message(event))
event_handler = (
lark_oapi.EventDispatcherHandler.builder('', '')
.register_p2_im_message_receive_v1(sync_on_message)
.build()
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
)
self.bot_account_id = config['bot_name']
self.bot = lark_oapi.ws.Client(
config['app_id'], config['app_secret'], event_handler=event_handler
)
self.api_client = (
lark_oapi.Client.builder()
.app_id(config['app_id'])
.app_secret(config['app_secret'])
.build()
)
self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
async def reply_message(
@@ -455,9 +414,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
):
# 不再需要了因为message_id已经被包含到message_chain中
# lark_event = await self.event_converter.yiri2target(message_source)
lark_message = await self.message_converter.yiri2target(
message, self.api_client
)
lark_message = await self.message_converter.yiri2target(message, self.api_client)
final_content = {
'zh_cn': {
@@ -480,9 +437,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
.build()
)
response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(
request
)
response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(request)
if not response.success():
raise Exception(
@@ -495,18 +450,14 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners.pop(event_type)

View File

@@ -29,9 +29,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
elif type(message_chain) is str:
msg_list = [platform_message.Plain(message_chain)]
else:
raise Exception(
'Unknown message type: ' + str(message_chain) + str(type(message_chain))
)
raise Exception('Unknown message type: ' + str(message_chain) + str(type(message_chain)))
nakuru_msg_list = []
@@ -63,9 +61,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
# 遍历并转换
for yiri_forward_node in yiri_forward_node_list:
try:
content_list = NakuruProjectMessageConverter.yiri2target(
yiri_forward_node.message_chain
)
content_list = NakuruProjectMessageConverter.yiri2target(yiri_forward_node.message_chain)
nakuru_forward_node = nkc.Node(
name=yiri_forward_node.sender_name,
uin=yiri_forward_node.sender_id,
@@ -87,9 +83,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
return nakuru_msg_list
@staticmethod
def target2yiri(
message_chain: typing.Any, message_id: int = -1
) -> platform_message.MessageChain:
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain:
"""将Yiri的消息链转换为YiriMirai的消息链"""
assert type(message_chain) is list
@@ -97,9 +91,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
import datetime
# 添加Source组件以标记message_id等信息
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
for component in message_chain:
if type(component) is nkc.Plain:
yiri_msg_list.append(platform_message.Plain(text=component.text))
@@ -130,9 +122,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
@staticmethod
def target2yiri(event: typing.Any) -> platform_events.Event:
yiri_chain = NakuruProjectMessageConverter.target2yiri(
event.message, event.message_id
)
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
if type(event) is nakuru.FriendMessage: # 私聊消息事件
return platform_events.FriendMessage(
sender=platform_entities.Friend(
@@ -206,9 +196,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
):
task = None
converted_msg = (
self.message_converter.yiri2target(message) if not converted else message
)
converted_msg = self.message_converter.yiri2target(message) if not converted else message
# 检查是否有转发消息
has_forward = False
@@ -250,13 +238,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
),
)
if type(message_source) is platform_events.GroupMessage:
await self.send_message(
'group', message_source.sender.group.id, message, converted=True
)
await self.send_message('group', message_source.sender.group.id, message, converted=True)
elif type(message_source) is platform_events.FriendMessage:
await self.send_message(
'person', message_source.sender.id, message, converted=True
)
await self.send_message('person', message_source.sender.id, message, converted=True)
else:
raise Exception('Unknown message source type: ' + str(type(message_source)))
@@ -264,17 +248,13 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
import time
# 检查是否被禁言
group_member_info = asyncio.run(
self.bot.getGroupMemberInfo(group_id, self.bot_account_id)
)
group_member_info = asyncio.run(self.bot.getGroupMemberInfo(group_id, self.bot_account_id))
return group_member_info.shut_up_timestamp > int(time.time())
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
try:
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
@@ -301,9 +281,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
@@ -312,10 +290,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
# 从本对象的监听器列表中查找并删除
target_wrapper = None
for listener in self.listener_list:
if (
listener['event_type'] == event_type
and listener['callable'] == callback
):
if listener['event_type'] == event_type and listener['callable'] == callback:
target_wrapper = listener['wrapper']
self.listener_list.remove(listener)
break
@@ -334,14 +309,8 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
import requests
resp = requests.get(
url='http://{}:{}/get_login_info'.format(
self.cfg['host'], self.cfg['http_port']
),
headers={
'Authorization': 'Bearer ' + self.cfg['token']
if 'token' in self.cfg
else ''
},
url='http://{}:{}/get_login_info'.format(self.cfg['host'], self.cfg['http_port']),
headers={'Authorization': 'Bearer ' + self.cfg['token'] if 'token' in self.cfg else ''},
timeout=5,
proxies=None,
)
@@ -349,9 +318,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
raise Exception('go-cqhttp拒绝访问请检查配置文件中nakuru适配器的配置')
self.bot_account_id = int(resp.json()['data']['user_id'])
except Exception:
raise Exception(
'获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确'
)
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
await self.bot._run()
self.ap.logger.info('运行 Nakuru 适配器')
while True:

View File

@@ -25,9 +25,7 @@ class OAMessageConverter(adapter.MessageConverter):
@staticmethod
async def target2yiri(message: str, message_id=-1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
yiri_msg_list.append(platform_message.Plain(text=message))
chain = platform_message.MessageChain(yiri_msg_list)
@@ -39,9 +37,7 @@ class OAEventConverter(adapter.EventConverter):
@staticmethod
async def target2yiri(event: OAEvent):
if event.type == 'text':
yiri_chain = await OAMessageConverter.target2yiri(
event.message, event.message_id
)
yiri_chain = await OAMessageConverter.target2yiri(event.message, event.message_id)
friend = platform_entities.Friend(
id=event.user_id,
@@ -81,9 +77,7 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ParamNotEnoughError(
'微信公众号缺少相关配置项,请查看文档或联系管理员'
)
raise ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员')
if self.config['Mode'] == 'drop':
self.bot = OAClient(
@@ -114,28 +108,20 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
await self.bot.set_message(message_source.message_chain.message_id, content)
elif isinstance(self.bot, OAClientForLongerResponse):
from_user = message_source.sender.id
await self.bot.set_message(
from_user, message_source.message_chain.message_id, content
)
await self.bot.set_message(from_user, message_source.message_chain.message_id, content)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
def register_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
async def on_message(event: OAEvent):
self.bot_account_id = event.receiver_id
try:
return await callback(
await self.event_converter.target2yiri(event), self
)
return await callback(await self.event_converter.target2yiri(event), self)
except Exception:
traceback.print_exc()
@@ -161,8 +147,6 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -147,9 +147,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
elif type(message_chain) is str:
msg_list = [platform_message.Plain(text=message_chain)]
else:
raise Exception(
'Unknown message type: ' + str(message_chain) + str(type(message_chain))
)
raise Exception('Unknown message type: ' + str(message_chain) + str(type(message_chain)))
offcial_messages: list[dict] = []
"""
@@ -172,19 +170,13 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
if component.url is not None:
offcial_messages.append({'type': 'image', 'content': component.url})
elif component.path is not None:
offcial_messages.append(
{'type': 'file_image', 'content': component.path}
)
offcial_messages.append({'type': 'file_image', 'content': component.path})
elif type(component) is platform_message.At:
offcial_messages.append({'type': 'at', 'content': ''})
elif type(component) is platform_message.AtAll:
print(
'上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。'
)
print('上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。')
elif type(component) is platform_message.Voice:
print(
'上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。'
)
print('上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。')
elif type(component) is forward.Forward:
# 转发消息
yiri_forward_node_list = component.node_list
@@ -195,9 +187,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
message_chain = yiri_forward_node.message_chain
# 平铺
offcial_messages.extend(
OfficialMessageConverter.yiri2target(message_chain)
)
offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain))
except Exception:
import traceback
@@ -219,11 +209,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
yiri_msg_list = []
# 存id
yiri_msg_list.append(
platform_message.Source(
id=save_msg_id(message_id), time=datetime.datetime.now()
)
)
yiri_msg_list.append(platform_message.Source(id=save_msg_id(message_id), time=datetime.datetime.now()))
if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
yiri_msg_list.append(platform_message.At(target=bot_account_id))
@@ -239,9 +225,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
if attachment.content_type.startswith('image'):
yiri_msg_list.append(platform_message.Image(url=attachment.url))
else:
logging.warning(
'不支持的附件类型:' + attachment.content_type + ',忽略此附件。'
)
logging.warning('不支持的附件类型:' + attachment.content_type + ',忽略此附件。')
content = re.sub(r'<@!\d+>', '', str(message.content))
if content.strip() != '':
@@ -264,9 +248,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
elif event == platform_events.FriendMessage:
return botpy_message.DirectMessage
else:
raise Exception(
'未支持转换的事件类型(YiriMirai -> Official): ' + str(event)
)
raise Exception('未支持转换的事件类型(YiriMirai -> Official): ' + str(event))
def target2yiri(
self,
@@ -297,21 +279,13 @@ class OfficialEventConverter(adapter_model.EventConverter):
),
special_title='',
join_timestamp=int(
datetime.datetime.strptime(
event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
datetime.datetime.strptime(event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z').timestamp()
),
last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
)
elif isinstance(event, botpy_message.DirectMessage): # 频道私聊,转私聊事件
return platform_events.FriendMessage(
@@ -320,14 +294,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
nickname=event.author.username,
remark=event.author.username,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
)
elif isinstance(event, botpy_message.GroupMessage): # 群聊,转群聊事件
author_member_id = event.author.member_openid
@@ -347,14 +315,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
)
elif isinstance(event, botpy_message.C2CMessage): # 私聊,转私聊事件
user_id_alter = event.author.user_openid
@@ -365,14 +327,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
nickname=user_id_alter,
remark=user_id_alter,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
)
@@ -420,9 +376,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
self.bot = botpy.Client(intents=intents)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
message_list = self.message_converter.yiri2target(message)
for msg in message_list:
@@ -468,22 +422,16 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
if quote_origin:
args['message_reference'] = botpy_message_type.Reference(
message_id=cached_message_ids[
str(message_source.message_chain.message_id)
]
message_id=cached_message_ids[str(message_source.message_chain.message_id)]
)
if isinstance(message_source, platform_events.GroupMessage):
args['channel_id'] = str(message_source.sender.group.id)
args['msg_id'] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_message(**args)
elif isinstance(message_source, platform_events.FriendMessage):
args['guild_id'] = str(message_source.sender.id)
args['msg_id'] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_dms(**args)
elif isinstance(message_source, OfficialGroupMessage):
if 'file_image' in args: # 暂不支持发送文件图片
@@ -502,9 +450,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
args['media'] = uploadMedia
args['msg_type'] = 7
args['msg_id'] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
args['msg_seq'] = self.group_msg_seq
self.group_msg_seq += 1
@@ -523,9 +469,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
args['media'] = uploadMedia
args['msg_type'] = 7
args['msg_id'] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
args['msg_seq'] = self.c2c_msg_seq
self.c2c_msg_seq += 1
@@ -538,9 +482,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
try:
@@ -563,9 +505,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
delattr(self.bot, event_handler_mapping[event_type])

View File

@@ -35,13 +35,9 @@ class QQOfficialMessageConverter(adapter.MessageConverter):
@staticmethod
async def target2yiri(message: str, message_id: str, pic_url: str, content_type):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
if pic_url is not None:
base64_url = await image.get_qq_official_image_base64(
pic_url=pic_url, content_type=content_type
)
base64_url = await image.get_qq_official_image_base64(pic_url=pic_url, content_type=content_type)
yiri_msg_list.append(platform_message.Image(base64=base64_url))
yiri_msg_list.append(platform_message.Plain(text=message))
@@ -75,11 +71,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
return platform_events.FriendMessage(
sender=friend,
message_chain=yiri_chain,
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
source_platform_object=event,
)
@@ -89,9 +81,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
nickname=event.t,
remark='',
)
return platform_events.FriendMessage(
sender=friend, message_chain=yiri_chain, source_platform_object=event
)
return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, source_platform_object=event)
if event.t == 'GROUP_AT_MESSAGE_CREATE':
yiri_chain.insert(0, platform_message.At(target='justbot'))
@@ -109,11 +99,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
last_speak_timestamp=0,
mute_time_remaining=0,
)
time = int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
)
time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp())
return platform_events.GroupMessage(
sender=sender,
message_chain=yiri_chain,
@@ -136,11 +122,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
last_speak_timestamp=0,
mute_time_remaining=0,
)
time = int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
)
time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp())
return platform_events.GroupMessage(
sender=sender,
message_chain=yiri_chain,
@@ -167,9 +149,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ParamNotEnoughError(
'QQ官方机器人缺少相关配置项请查看文档或联系管理员'
)
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
self.bot = QQOfficialClient(
app_id=config['appid'],
@@ -229,24 +209,18 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
qq_official_event.d_id,
)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: QQOfficialEvent):
self.bot_account_id = 'justbot'
try:
return await callback(
await self.event_converter.target2yiri(event), self
)
return await callback(await self.event_converter.target2yiri(event), self)
except Exception:
traceback.print_exc()
@@ -274,8 +248,6 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -11,37 +11,32 @@ from pkg.platform.types import events as platform_events, message as platform_me
from libs.slack_api.slackevent import SlackEvent
from pkg.core import app
from .. import adapter
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError
from ...utils import image
class SlackMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(message_chain:platform_message.MessageChain):
async def yiri2target(message_chain: platform_message.MessageChain):
content_list = []
for msg in message_chain:
if type(msg) is platform_message.Plain:
content_list.append({
"content":msg.text,
})
content_list.append(
{
'content': msg.text,
}
)
return content_list
@staticmethod
async def target2yiri(message:str,message_id:str,pic_url:str,bot:SlackClient):
async def target2yiri(message: str, message_id: str, pic_url: str, bot: SlackClient):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id,time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
if pic_url is not None:
base64_url = await image.get_slack_image_to_base64(pic_url=pic_url,bot_token=bot.bot_token)
yiri_msg_list.append(
platform_message.Image(base64=base64_url)
)
base64_url = await image.get_slack_image_to_base64(pic_url=pic_url, bot_token=bot.bot_token)
yiri_msg_list.append(platform_message.Image(base64=base64_url))
yiri_msg_list.append(platform_message.Plain(text=message))
chain = platform_message.MessageChain(yiri_msg_list)
@@ -49,55 +44,43 @@ class SlackMessageConverter(adapter.MessageConverter):
class SlackEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event:platform_events.MessageEvent) -> SlackEvent:
async def yiri2target(event: platform_events.MessageEvent) -> SlackEvent:
return event.source_platform_object
@staticmethod
async def target2yiri(event:SlackEvent,bot:SlackClient):
async def target2yiri(event: SlackEvent, bot: SlackClient):
yiri_chain = await SlackMessageConverter.target2yiri(
message=event.text,message_id=event.message_id,pic_url=event.pic_url,bot=bot
message=event.text, message_id=event.message_id, pic_url=event.pic_url, bot=bot
)
if event.type == 'channel':
yiri_chain.insert(0, platform_message.At(target="SlackBot"))
yiri_chain.insert(0, platform_message.At(target='SlackBot'))
sender = platform_entities.GroupMember(
id = event.user_id,
member_name= str(event.sender_name),
permission= 'MEMBER',
group = platform_entities.Group(
id = event.channel_id,
name = 'MEMBER',
permission= platform_entities.Permission.Member
id=event.user_id,
member_name=str(event.sender_name),
permission='MEMBER',
group=platform_entities.Group(
id=event.channel_id, name='MEMBER', permission=platform_entities.Permission.Member
),
special_title='',
join_timestamp=0,
last_speak_timestamp=0,
mute_time_remaining=0
mute_time_remaining=0,
)
time = int(datetime.datetime.utcnow().timestamp())
return platform_events.GroupMessage(
sender = sender,
message_chain=yiri_chain,
time = time,
source_platform_object=event
sender=sender, message_chain=yiri_chain, time=time, source_platform_object=event
)
if event.type == 'im':
return platform_events.FriendMessage(
sender=platform_entities.Friend(
id=event.user_id,
nickname = event.sender_name,
remark=""
),
message_chain = yiri_chain,
time = float(datetime.datetime.now().timestamp()),
sender=platform_entities.Friend(id=event.user_id, nickname=event.sender_name, remark=''),
message_chain=yiri_chain,
time=float(datetime.datetime.now().timestamp()),
source_platform_object=event,
)
class SlackAdapter(adapter.MessagePlatformAdapter):
@@ -108,21 +91,18 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
event_converter: SlackEventConverter = SlackEventConverter()
config: dict
def __init__(self,config:dict,ap:app.Application):
def __init__(self, config: dict, ap: app.Application):
self.config = config
self.ap = ap
required_keys = [
"bot_token",
"signing_secret",
'bot_token',
'signing_secret',
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ParamNotEnoughError("Slack机器人缺少相关配置项请查看文档或联系管理员")
raise ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员')
self.bot = SlackClient(
bot_token=self.config["bot_token"],
signing_secret=self.config["signing_secret"]
)
self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'])
async def reply_message(
self,
@@ -130,52 +110,40 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
message: platform_message.MessageChain,
quote_origin: bool = False,
):
slack_event = await SlackEventConverter.yiri2target(
message_source
)
slack_event = await SlackEventConverter.yiri2target(message_source)
content_list = await SlackMessageConverter.yiri2target(message)
content_list = await SlackMessageConverter.yiri2target(message)
for content in content_list:
if slack_event.type == 'channel':
await self.bot.send_message_to_channel(
content['content'],slack_event.channel_id
)
await self.bot.send_message_to_channel(content['content'], slack_event.channel_id)
if slack_event.type == 'im':
await self.bot.send_message_to_one(
content['content'],slack_event.user_id
)
await self.bot.send_message_to_one(content['content'], slack_event.user_id)
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
content_list = await SlackMessageConverter.yiri2target(message)
for content in content_list:
if target_type == 'person':
await self.bot.send_message_to_one(content['content'],target_id)
await self.bot.send_message_to_one(content['content'], target_id)
if target_type == 'group':
await self.bot.send_message_to_channel(content['content'],target_id)
await self.bot.send_message_to_channel(content['content'], target_id)
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event:SlackEvent):
async def on_message(event: SlackEvent):
self.bot_account_id = 'SlackBot'
try:
return await callback(
await self.event_converter.target2yiri(event,self.bot),self
)
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
except:
traceback.print_exc()
if event_type == platform_events.FriendMessage:
self.bot.on_message("im")(on_message)
elif event_type == platform_events.GroupMessage:
self.bot.on_message("channel")(on_message)
if event_type == platform_events.FriendMessage:
self.bot.on_message('im')(on_message)
elif event_type == platform_events.GroupMessage:
self.bot.on_message('channel')(on_message)
async def run_async(self):
async def shutdown_trigger_placeholder():
@@ -183,8 +151,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
await asyncio.sleep(1)
await self.bot.run_task(
host="0.0.0.0",
port=self.config["port"],
host='0.0.0.0',
port=self.config['port'],
shutdown_trigger=shutdown_trigger_placeholder,
)
@@ -197,8 +165,3 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -21,9 +21,7 @@ from ..types import entities as platform_entities
class TelegramMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain, bot: telegram.Bot
) -> list[dict]:
async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]:
components = []
for component in message_chain:
@@ -45,18 +43,12 @@ class TelegramMessageConverter(adapter.MessageConverter):
components.append({'type': 'photo', 'photo': photo_bytes})
elif isinstance(component, platform_message.Forward):
for node in component.node_list:
components.extend(
await TelegramMessageConverter.yiri2target(
node.message_chain, bot
)
)
components.extend(await TelegramMessageConverter.yiri2target(node.message_chain, bot))
return components
@staticmethod
async def target2yiri(
message: telegram.Message, bot: telegram.Bot, bot_account_id: str
):
async def target2yiri(message: telegram.Message, bot: telegram.Bot, bot_account_id: str):
message_components = []
def parse_message_text(text: str) -> list[platform_message.MessageComponent]:
@@ -103,9 +95,7 @@ class TelegramEventConverter(adapter.EventConverter):
@staticmethod
async def target2yiri(event: Update, bot: telegram.Bot, bot_account_id: str):
lb_message = await TelegramMessageConverter.target2yiri(
event.message, bot, bot_account_id
)
lb_message = await TelegramMessageConverter.target2yiri(event.message, bot, bot_account_id)
if event.effective_chat.type == 'private':
return platform_events.FriendMessage(
@@ -166,9 +156,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
return
try:
lb_event = await self.event_converter.target2yiri(
update, self.bot, self.bot_account_id
)
lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id)
await self.listeners[type(lb_event)](lb_event, self)
except Exception:
print(traceback.format_exc())
@@ -176,14 +164,10 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
self.application = ApplicationBuilder().token(self.config['token']).build()
self.bot = self.application.bot
self.application.add_handler(
MessageHandler(
filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback
)
MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback)
)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
async def reply_message(
@@ -210,9 +194,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
if self.config['markdown_card'] is True:
args['parse_mode'] = 'MarkdownV2'
if quote_origin:
args['reply_to_message_id'] = (
message_source.source_platform_object.message.id
)
args['reply_to_message_id'] = message_source.source_platform_object.message.id
await self.bot.send_message(**args)
@@ -222,18 +204,14 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners.pop(event_type)

View File

@@ -18,9 +18,7 @@ from ...utils import image
class WecomMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain, bot: WecomClient
):
async def yiri2target(message_chain: platform_message.MessageChain, bot: WecomClient):
content_list = []
for msg in message_chain:
@@ -40,13 +38,7 @@ class WecomMessageConverter(adapter.MessageConverter):
)
elif type(msg) is platform_message.Forward:
for node in msg.node_list:
content_list.extend(
(
await WecomMessageConverter.yiri2target(
node.message_chain, bot
)
)
)
content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot)))
else:
content_list.append(
{
@@ -60,9 +52,7 @@ class WecomMessageConverter(adapter.MessageConverter):
@staticmethod
async def target2yiri(message: str, message_id: int = -1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
yiri_msg_list.append(platform_message.Plain(text=message))
chain = platform_message.MessageChain(yiri_msg_list)
@@ -72,15 +62,9 @@ class WecomMessageConverter(adapter.MessageConverter):
@staticmethod
async def target2yiri_image(picurl: str, message_id: int = -1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
image_base64, image_format = await image.get_wecom_image_base64(pic_url=picurl)
yiri_msg_list.append(
platform_message.Image(
base64=f'data:image/{image_format};base64,{image_base64}'
)
)
yiri_msg_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
chain = platform_message.MessageChain(yiri_msg_list)
return chain
@@ -88,9 +72,7 @@ class WecomMessageConverter(adapter.MessageConverter):
class WecomEventConverter:
@staticmethod
async def yiri2target(
event: platform_events.Event, bot_account_id: int, bot: WecomClient
) -> WecomEvent:
async def yiri2target(event: platform_events.Event, bot_account_id: int, bot: WecomClient) -> WecomEvent:
# only for extracting user information
if type(event) is platform_events.GroupMessage:
@@ -124,18 +106,14 @@ class WecomEventConverter:
"""
# 转换消息链
if event.type == 'text':
yiri_chain = await WecomMessageConverter.target2yiri(
event.message, event.message_id
)
yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id)
friend = platform_entities.Friend(
id=f'u{event.user_id}',
nickname=str(event.agent_id),
remark='',
)
return platform_events.FriendMessage(
sender=friend, message_chain=yiri_chain, time=event.timestamp
)
return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, time=event.timestamp)
elif event.type == 'image':
friend = platform_entities.Friend(
id=f'u{event.user_id}',
@@ -143,13 +121,9 @@ class WecomEventConverter:
remark='',
)
yiri_chain = await WecomMessageConverter.target2yiri_image(
picurl=event.picurl, message_id=event.message_id
)
yiri_chain = await WecomMessageConverter.target2yiri_image(picurl=event.picurl, message_id=event.message_id)
return platform_events.FriendMessage(
sender=friend, message_chain=yiri_chain, time=event.timestamp
)
return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, time=event.timestamp)
class WecomAdapter(adapter.MessagePlatformAdapter):
@@ -190,26 +164,18 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
message: platform_message.MessageChain,
quote_origin: bool = False,
):
Wecom_event = await WecomEventConverter.yiri2target(
message_source, self.bot_account_id, self.bot
)
Wecom_event = await WecomEventConverter.yiri2target(message_source, self.bot_account_id, self.bot)
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
fixed_user_id = Wecom_event.user_id
# 删掉开头的u
fixed_user_id = fixed_user_id[1:]
for content in content_list:
if content['type'] == 'text':
await self.bot.send_private_msg(
fixed_user_id, Wecom_event.agent_id, content['content']
)
await self.bot.send_private_msg(fixed_user_id, Wecom_event.agent_id, content['content'])
elif content['type'] == 'image':
await self.bot.send_image(
fixed_user_id, Wecom_event.agent_id, content['media_id']
)
await self.bot.send_image(fixed_user_id, Wecom_event.agent_id, content['media_id'])
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""企业微信目前只有发送给个人的方法,
构造target_id的方式为前半部分为账户id后半部分为agent_id,中间使用“|”符号隔开。
"""
@@ -220,25 +186,19 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
if target_type == 'person':
for content in content_list:
if content['type'] == 'text':
await self.bot.send_private_msg(
user_id, agent_id, content['content']
)
await self.bot.send_private_msg(user_id, agent_id, content['content'])
if content['type'] == 'image':
await self.bot.send_image(user_id, agent_id, content['media'])
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: WecomEvent):
self.bot_account_id = event.receiver_id
try:
return await callback(
await self.event_converter.target2yiri(event), self
)
return await callback(await self.event_converter.target2yiri(event), self)
except Exception:
traceback.print_exc()
@@ -265,8 +225,6 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -11,49 +11,47 @@ from pkg.platform.types import events as platform_events, message as platform_me
from libs.wecom_customer_service_api.wecomcsevent import WecomCSEvent
from pkg.core import app
from .. import adapter
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError
from ...utils import image
class WecomMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain, bot: WecomCSClient
):
async def yiri2target(message_chain: platform_message.MessageChain, bot: WecomCSClient):
content_list = []
for msg in message_chain:
if type(msg) is platform_message.Plain:
content_list.append({
"type": "text",
"content": msg.text,
})
content_list.append(
{
'type': 'text',
'content': msg.text,
}
)
elif type(msg) is platform_message.Image:
content_list.append({
"type": "image",
"media_id": await bot.get_media_id(msg),
})
content_list.append(
{
'type': 'image',
'media_id': await bot.get_media_id(msg),
}
)
elif type(msg) is platform_message.Forward:
for node in msg.node_list:
content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot)))
else:
content_list.append({
"type": "text",
"content": str(msg),
})
content_list.append(
{
'type': 'text',
'content': str(msg),
}
)
return content_list
@staticmethod
async def target2yiri(message: str, message_id: int = -1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
yiri_msg_list.append(platform_message.Plain(text=message))
chain = platform_message.MessageChain(yiri_msg_list)
@@ -63,21 +61,16 @@ class WecomMessageConverter(adapter.MessageConverter):
@staticmethod
async def target2yiri_image(picurl: str, message_id: int = -1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
yiri_msg_list.append(platform_message.Image(base64=picurl))
chain = platform_message.MessageChain(yiri_msg_list)
return chain
class WecomEventConverter:
@staticmethod
async def yiri2target(
event: platform_events.Event, bot_account_id: int, bot: WecomCSClient
) -> WecomCSEvent:
async def yiri2target(event: platform_events.Event, bot_account_id: int, bot: WecomCSClient) -> WecomCSEvent:
# only for extracting user information
if type(event) is platform_events.GroupMessage:
@@ -98,29 +91,25 @@ class WecomEventConverter:
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
"""
# 转换消息链
if event.type == "text":
yiri_chain = await WecomMessageConverter.target2yiri(
event.message, event.message_id
)
if event.type == 'text':
yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id)
friend = platform_entities.Friend(
id=f"u{event.user_id}",
id=f'u{event.user_id}',
nickname=str(event.user_id),
remark="",
remark='',
)
return platform_events.FriendMessage(
sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event
)
elif event.type == "image":
elif event.type == 'image':
friend = platform_entities.Friend(
id=f"u{event.user_id}",
id=f'u{event.user_id}',
nickname=str(event.user_id),
remark="",
remark='',
)
yiri_chain = await WecomMessageConverter.target2yiri_image(
picurl=event.picurl, message_id=event.message_id
)
yiri_chain = await WecomMessageConverter.target2yiri_image(picurl=event.picurl, message_id=event.message_id)
return platform_events.FriendMessage(
sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event
@@ -128,7 +117,6 @@ class WecomEventConverter:
class WecomCSAdapter(adapter.MessagePlatformAdapter):
bot: WecomCSClient
ap: app.Application
bot_account_id: str
@@ -142,20 +130,20 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
self.ap = ap
required_keys = [
"corpid",
"secret",
"token",
"EncodingAESKey",
'corpid',
'secret',
'token',
'EncodingAESKey',
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ParamNotEnoughError("企业微信客服缺少相关配置项,请查看文档或联系管理员")
raise ParamNotEnoughError('企业微信客服缺少相关配置项,请查看文档或联系管理员')
self.bot = WecomCSClient(
corpid=config["corpid"],
secret=config["secret"],
token=config["token"],
EncodingAESKey=config["EncodingAESKey"],
corpid=config['corpid'],
secret=config['secret'],
token=config['token'],
EncodingAESKey=config['EncodingAESKey'],
)
async def reply_message(
@@ -164,40 +152,36 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
message: platform_message.MessageChain,
quote_origin: bool = False,
):
Wecom_event = await WecomEventConverter.yiri2target(
message_source, self.bot_account_id, self.bot
)
Wecom_event = await WecomEventConverter.yiri2target(message_source, self.bot_account_id, self.bot)
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
for content in content_list:
if content["type"] == "text":
await self.bot.send_text_msg(open_kfid=Wecom_event.receiver_id,external_userid=Wecom_event.user_id,msgid=Wecom_event.message_id,content=content["content"])
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
if content['type'] == 'text':
await self.bot.send_text_msg(
open_kfid=Wecom_event.receiver_id,
external_userid=Wecom_event.user_id,
msgid=Wecom_event.message_id,
content=content['content'],
)
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: WecomCSEvent):
self.bot_account_id = event.receiver_id
try:
return await callback(
await self.event_converter.target2yiri(event), self
)
return await callback(await self.event_converter.target2yiri(event), self)
except:
traceback.print_exc()
if event_type == platform_events.FriendMessage:
self.bot.on_message("text")(on_message)
self.bot.on_message("image")(on_message)
self.bot.on_message('text')(on_message)
self.bot.on_message('image')(on_message)
elif event_type == platform_events.GroupMessage:
pass
@@ -207,8 +191,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
await asyncio.sleep(1)
await self.bot.run_task(
host="0.0.0.0",
port=self.config["port"],
host='0.0.0.0',
port=self.config['port'],
shutdown_trigger=shutdown_trigger_placeholder,
)
@@ -220,4 +204,4 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
event_type: type,
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)
return super().unregister_listener(event_type, callback)

View File

@@ -31,10 +31,7 @@ class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass):
def __repr__(self) -> str:
return (
self.__class__.__name__
+ '('
+ ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if v))
+ ')'
self.__class__.__name__ + '(' + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)) + ')'
)
class Config:

View File

@@ -25,13 +25,7 @@ class Event(pydantic.BaseModel):
return (
self.__class__.__name__
+ '('
+ ', '.join(
(
f'{k}={repr(v)}'
for k, v in self.__dict__.items()
if k != 'type' and v
)
)
+ ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if k != 'type' and v))
+ ')'
)

View File

@@ -51,13 +51,7 @@ class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass
return (
self.__class__.__name__
+ '('
+ ', '.join(
(
f'{k}={repr(v)}'
for k, v in self.__dict__.items()
if k != 'type' and v
)
)
+ ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if k != 'type' and v))
+ ')'
)
@@ -65,14 +59,10 @@ class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass
# 解析参数列表,将位置参数转化为具名参数
parameter_names = self.__parameter_names__
if len(args) > len(parameter_names):
raise TypeError(
f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。'
)
raise TypeError(f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。')
for name, value in zip(parameter_names, args):
if name in kwargs:
raise TypeError(
f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。'
)
raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。')
kwargs[name] = value
super().__init__(**kwargs)
@@ -140,9 +130,7 @@ class MessageChain(PlatformBaseModel):
elif isinstance(msg, str):
result.append(Plain(msg))
else:
raise TypeError(
f'消息链中元素需为 dict 或 str 或 MessageComponent当前类型{type(msg)}'
)
raise TypeError(f'消息链中元素需为 dict 或 str 或 MessageComponent当前类型{type(msg)}')
return result
@pydantic.validator('__root__', always=True, pre=True)
@@ -175,9 +163,7 @@ class MessageChain(PlatformBaseModel):
def __iter__(self):
yield from self.__root__
def get_first(
self, t: typing.Type[TMessageComponent]
) -> typing.Optional[TMessageComponent]:
def get_first(self, t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]:
"""获取消息链中第一个符合类型的消息组件。"""
for component in self:
if isinstance(component, t):
@@ -191,9 +177,7 @@ class MessageChain(PlatformBaseModel):
def __getitem__(self, index: slice) -> typing.List[MessageComponent]: ...
@typing.overload
def __getitem__(
self, index: typing.Type[TMessageComponent]
) -> typing.List[TMessageComponent]: ...
def __getitem__(self, index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]: ...
@typing.overload
def __getitem__(
@@ -208,17 +192,13 @@ class MessageChain(PlatformBaseModel):
typing.Type[TMessageComponent],
typing.Tuple[typing.Type[TMessageComponent], int],
],
) -> typing.Union[
MessageComponent, typing.List[MessageComponent], typing.List[TMessageComponent]
]:
) -> typing.Union[MessageComponent, typing.List[MessageComponent], typing.List[TMessageComponent]]:
return self.get(index)
def __setitem__(
self,
key: typing.Union[int, slice],
value: typing.Union[
MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, str]]
],
value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, str]]],
):
if isinstance(value, str):
value = Plain(value)
@@ -234,9 +214,7 @@ class MessageChain(PlatformBaseModel):
def has(
self,
sub: typing.Union[
MessageComponent, typing.Type[MessageComponent], 'MessageChain', str
],
sub: typing.Union[MessageComponent, typing.Type[MessageComponent], 'MessageChain', str],
) -> bool:
"""判断消息链中:
1. 是否有某个消息组件。
@@ -271,9 +249,7 @@ class MessageChain(PlatformBaseModel):
def __len__(self) -> int:
return len(self.__root__)
def __add__(
self, other: typing.Union['MessageChain', MessageComponent, str]
) -> 'MessageChain':
def __add__(self, other: typing.Union['MessageChain', MessageComponent, str]) -> 'MessageChain':
if isinstance(other, MessageChain):
return self.__class__(self.__root__ + other.__root__)
if isinstance(other, str):
@@ -286,9 +262,7 @@ class MessageChain(PlatformBaseModel):
if isinstance(other, MessageComponent):
return self.__class__([other] + self.__root__)
if isinstance(other, str):
return self.__class__(
[typing.cast(MessageComponent, Plain(other))] + self.__root__
)
return self.__class__([typing.cast(MessageComponent, Plain(other))] + self.__root__)
return NotImplemented
def __mul__(self, other: int):
@@ -346,9 +320,7 @@ class MessageChain(PlatformBaseModel):
return self.__root__.index(x, i, j)
raise TypeError(f'类型不匹配,当前类型:{type(x)}')
def count(
self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]
) -> int:
def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int:
"""返回消息链中 x 出现的次数。
Args:
@@ -443,9 +415,7 @@ class MessageChain(PlatformBaseModel):
@classmethod
def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]):
return cls(
Plain(c) if isinstance(c, str) else c for c in itertools.chain(*args)
)
return cls(Plain(c) if isinstance(c, str) else c for c in itertools.chain(*args))
@property
def source(self) -> typing.Optional['Source']:
@@ -557,11 +527,7 @@ class Image(MessageComponent):
"""图片的 Base64 编码。"""
def __eq__(self, other):
return (
isinstance(other, Image)
and self.type == other.type
and self.uuid == other.uuid
)
return isinstance(other, Image) and self.type == other.type and self.uuid == other.uuid
def __str__(self):
return '[图片]'
@@ -818,9 +784,7 @@ class ForwardMessageNode(pydantic.BaseModel):
Returns:
ForwardMessageNode: 生成的一条消息。
"""
return ForwardMessageNode(
sender_id=sender.id, sender_name=sender.get_name(), message_chain=message
)
return ForwardMessageNode(sender_id=sender.id, sender_name=sender.get_name(), message_chain=message)
class ForwardMessageDiaplay(pydantic.BaseModel):

View File

@@ -165,9 +165,7 @@ class APIHost:
langbot_version = ''
try:
langbot_version = (
self.ap.ver_mgr.get_current_version()
) # 从updater模块获取版本号
langbot_version = self.ap.ver_mgr.get_current_version() # 从updater模块获取版本号
except Exception:
return False
@@ -237,9 +235,7 @@ class EventContext:
message_source=self.event.query.message_event, message=message_chain
)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息
Args:
@@ -247,9 +243,7 @@ class EventContext:
target_id (str): 目标ID
message (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
"""
await self.event.query.adapter.send_message(
target_type=target_type, target_id=target_id, message=message
)
await self.event.query.adapter.send_message(target_type=target_type, target_id=target_id, message=message)
def prevent_postorder(self):
"""阻止后续插件执行"""
@@ -378,8 +372,7 @@ class RuntimeContainer(pydantic.BaseModel):
'priority': self.priority,
'config_schema': self.config_schema,
'event_handlers': {
event_name.__name__: handler.__name__
for event_name, handler in self.event_handlers.items()
event_name.__name__: handler.__name__ for event_name, handler in self.event_handlers.items()
},
'tools': [
{

View File

@@ -58,9 +58,7 @@ class GitHubRepoInstaller(installer.PluginInstaller):
ssl=ssl_context, # 使用自定义SSL上下文来验证证书
) as resp:
if resp.status != 200:
raise errors.PluginInstallerError(
f'下载源码失败: {await resp.text()}'
)
raise errors.PluginInstallerError(f'下载源码失败: {await resp.text()}')
zip_resp = await resp.read()
if await aiofiles_os.path.exists('temp/' + target_path):
@@ -101,9 +99,7 @@ class GitHubRepoInstaller(installer.PluginInstaller):
):
"""安装插件"""
task_context.trace('下载插件源码...', 'install-plugin')
repo_label = await self.download_plugin_source_code(
plugin_source, 'plugins/', task_context
)
repo_label = await self.download_plugin_source_code(plugin_source, 'plugins/', task_context)
task_context.trace('安装插件依赖...', 'install-plugin')
await self.install_requirements('plugins/' + repo_label)
task_context.trace('完成.', 'install-plugin')

View File

@@ -35,16 +35,12 @@ class PluginLoader(loader.PluginLoader):
def register(
self, name: str, description: str, version: str, author: str
) -> typing.Callable[
[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]
]:
) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]:
self.ap.logger.debug(f'注册插件 {name} {version} by {author}')
container = context.RuntimeContainer(
plugin_name=name,
plugin_label=discover_engine.I18nString(en_US=name, zh_CN=name),
plugin_description=discover_engine.I18nString(
en_US=description, zh_CN=description
),
plugin_description=discover_engine.I18nString(en_US=description, zh_CN=description),
plugin_version=version,
plugin_author=author,
plugin_repository='',
@@ -64,16 +60,12 @@ class PluginLoader(loader.PluginLoader):
# 过时
# 最早将于 v3.4 版本移除
def on(
self, event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
def on(self, event: typing.Type[events.BaseEventModel]) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册过时的事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
async def handler(
plugin: context.BasePlugin, ctx: context.EventContext
) -> None:
async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None:
args = {
'host': ctx.host,
'event': ctx,
@@ -104,15 +96,9 @@ class PluginLoader(loader.PluginLoader):
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = funcschema.get_func_schema(func)
function_name = (
self._current_container.plugin_name
+ '-'
+ (func.__name__ if name is None else name)
)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
async def handler(
plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs
):
async def handler(plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs):
return func(*args, **kwargs)
llm_function = tools_entities.LLMFunction(
@@ -129,9 +115,7 @@ class PluginLoader(loader.PluginLoader):
return wrapper
def handler(
self, event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
def handler(self, event: typing.Type[events.BaseEventModel]) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
@@ -161,11 +145,7 @@ class PluginLoader(loader.PluginLoader):
return func
function_schema = funcschema.get_func_schema(func)
function_name = (
self._current_container.plugin_name
+ '-'
+ (func.__name__ if name is None else name)
)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
llm_function = tools_entities.LLMFunction(
name=function_name,
@@ -193,9 +173,7 @@ class PluginLoader(loader.PluginLoader):
else:
try:
self._current_pkg_path = 'plugins/' + path_prefix
self._current_module_path = (
'plugins/' + path_prefix + item.name + '.py'
)
self._current_module_path = 'plugins/' + path_prefix + item.name + '.py'
self._current_container = None
@@ -205,9 +183,7 @@ class PluginLoader(loader.PluginLoader):
self.plugins.append(self._current_container)
self.ap.logger.debug(f'插件 {self._current_container} 已加载')
except Exception:
self.ap.logger.error(
f'加载插件模块 {prefix + item.name} 时发生错误'
)
self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
traceback.print_exc()
async def load_plugins(self):

Some files were not shown because too many files have changed in this diff Show More