feat: 消息截断器

This commit is contained in:
RockChinQ
2024-05-29 20:34:49 +08:00
parent 21fe5822f9
commit c8eb2e3376
9 changed files with 156 additions and 1 deletions

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from .. import truncator
from ....core import entities as core_entities
@truncator.truncator_class("round")
class RoundTruncator(truncator.Truncator):
"""前文回合数阶段器
"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
"""
max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']
temp_messages = []
current_round = 0
# 从后往前遍历
for msg in query.messages[::-1]:
if current_round < max_round:
temp_messages.append(msg)
if msg.role == 'user':
current_round += 1
else:
break
query.messages = temp_messages[::-1]
return query