mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 00:06:04 +00:00
chore: adjust dir structure
This commit is contained in:
0
src/langbot/pkg/pipeline/msgtrun/__init__.py
Normal file
0
src/langbot/pkg/pipeline/msgtrun/__init__.py
Normal file
35
src/langbot/pkg/pipeline/msgtrun/msgtrun.py
Normal file
35
src/langbot/pkg/pipeline/msgtrun/msgtrun.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities
|
||||
from . import truncator
|
||||
from ...utils import importutil
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from . import truncators
|
||||
|
||||
importutil.import_modules_in_pkg(truncators)
|
||||
|
||||
|
||||
@stage.stage_class('ConversationMessageTruncator')
|
||||
class ConversationMessageTruncator(stage.PipelineStage):
|
||||
"""Conversation message truncator
|
||||
|
||||
Used to truncate the conversation message chain to adapt to the LLM message length limit.
|
||||
"""
|
||||
|
||||
trun: truncator.Truncator
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
use_method = 'round'
|
||||
|
||||
for trun in truncator.preregistered_truncators:
|
||||
if trun.name == use_method:
|
||||
self.trun = trun(self.ap)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'Unknown truncator: {use_method}')
|
||||
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
query = await self.trun.truncate(query)
|
||||
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
56
src/langbot/pkg/pipeline/msgtrun/truncator.py
Normal file
56
src/langbot/pkg/pipeline/msgtrun/truncator.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
preregistered_truncators: list[typing.Type[Truncator]] = []
|
||||
|
||||
|
||||
def truncator_class(
|
||||
name: str,
|
||||
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
|
||||
"""截断器类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 截断器名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
|
||||
assert issubclass(cls, Truncator)
|
||||
|
||||
cls.name = name
|
||||
|
||||
preregistered_truncators.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Truncator(abc.ABC):
|
||||
"""消息截断器基类"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||
"""截断
|
||||
|
||||
一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。
|
||||
请勿操作其他字段。
|
||||
"""
|
||||
pass
|
||||
30
src/langbot/pkg/pipeline/msgtrun/truncators/round.py
Normal file
30
src/langbot/pkg/pipeline/msgtrun/truncators/round.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import truncator
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@truncator.truncator_class('round')
|
||||
class RoundTruncator(truncator.Truncator):
|
||||
"""Truncate the conversation message chain to adapt to the LLM message length limit."""
|
||||
|
||||
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||
"""截断"""
|
||||
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
||||
|
||||
temp_messages = []
|
||||
|
||||
current_round = 0
|
||||
|
||||
# Traverse from back to front
|
||||
for msg in query.messages[::-1]:
|
||||
if current_round < max_round:
|
||||
temp_messages.append(msg)
|
||||
if msg.role == 'user':
|
||||
current_round += 1
|
||||
else:
|
||||
break
|
||||
|
||||
query.messages = temp_messages[::-1]
|
||||
|
||||
return query
|
||||
Reference in New Issue
Block a user