diff --git a/pkg/command/operator.py b/pkg/command/operator.py index a666f2c3..307e9fbe 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -8,18 +8,34 @@ from . import entities preregistered_operators: list[typing.Type[CommandOperator]] = [] -"""预注册算子列表。在初始化时,所有算子类会被注册到此列表中。""" +"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。""" def operator_class( name: str, - help: str, + help: str = "", usage: str = None, alias: list[str] = [], privilege: int=1, # 1为普通用户,2为管理员 parent_class: typing.Type[CommandOperator] = None ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: + """命令类装饰器 + + Args: + name (str): 名称 + help (str, optional): 帮助信息. Defaults to "". + usage (str, optional): 使用说明. Defaults to None. + alias (list[str], optional): 别名. Defaults to []. + privilege (int, optional): 权限,1为普通用户可用,2为仅管理员可用. Defaults to 1. + parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None. + + Returns: + typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器 + """ + def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: + assert issubclass(cls, CommandOperator) + cls.name = name cls.alias = alias cls.help = help diff --git a/pkg/core/app.py b/pkg/core/app.py index ed035e5d..0d726a44 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -6,7 +6,7 @@ import traceback from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr -from ..provider.requester import modelmgr as llm_model_mgr +from ..provider.modelmgr import modelmgr as llm_model_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 78bcf1fe..8bf1ff2e 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -9,7 +9,7 @@ import pydantic import mirai from ..provider import entities as llm_entities -from ..provider.requester import entities +from ..provider.modelmgr import entities from ..provider.sysprompt import entities as sysprompt_entities from ..provider.tools import entities as tools_entities from ..platform import adapter as msadapter diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 24bef7cf..552f5611 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -10,7 +10,7 @@ from ...pipeline import pool, controller, stagemgr from ...plugin import manager as plugin_mgr from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr -from ...provider.requester import modelmgr as llm_model_mgr +from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr from ...platform import manager as im_mgr diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 92157bdd..fee2cd3f 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -7,7 +7,7 @@ from ...core import app from .. import stage, entities, stagemgr from ...core import entities as core_entities from ...config import manager as cfg_mgr -from . import filter, entities as filter_entities +from . import filter as filter_model, entities as filter_entities from .filters import cntignore, banwords, baiduexamine @@ -16,20 +16,29 @@ from .filters import cntignore, banwords, baiduexamine class ContentFilterStage(stage.PipelineStage): """内容过滤阶段""" - filter_chain: list[filter.ContentFilter] + filter_chain: list[filter_model.ContentFilter] def __init__(self, ap: app.Application): self.filter_chain = [] super().__init__(ap) async def initialize(self): - self.filter_chain.append(cntignore.ContentIgnore(self.ap)) + + filters_required = [ + "content-filter" + ] if self.ap.pipeline_cfg.data['check-sensitive-words']: - self.filter_chain.append(banwords.BanWordFilter(self.ap)) - + filters_required.append("ban-word-filter") + if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: - self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) + filters_required.append("baidu-cloud-examine") + + for filter in filter_model.preregistered_filters: + if filter.name in filters_required: + self.filter_chain.append( + filter(self.ap) + ) for filter in self.filter_chain: await filter.initialize() diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 57792145..23471392 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -1,12 +1,42 @@ # 内容过滤器的抽象类 from __future__ import annotations import abc +import typing from ...core import app from . import entities +preregistered_filters: list[typing.Type[ContentFilter]] = [] + + +def filter_class( + name: str +) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: + """内容过滤器类装饰器 + + Args: + name (str): 过滤器名称 + + Returns: + typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器 + """ + def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]: + assert issubclass(cls, ContentFilter) + + cls.name = name + + preregistered_filters.append(cls) + + return cls + + return decorator + + class ContentFilter(metaclass=abc.ABCMeta): + """内容过滤器抽象类""" + + name: str ap: app.Application diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index f72fe960..8c5b77cd 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" +@filter_model.filter_class("baidu-cloud-examine") class BaiduCloudExamine(filter_model.ContentFilter): """百度云内容审核""" diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 587f81c3..9391971c 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -6,6 +6,7 @@ from .. import entities from ....config import manager as cfg_mgr +@filter_model.filter_class("ban-word-filter") class BanWordFilter(filter_model.ContentFilter): """根据内容禁言""" diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index 92fe94e8..781f6397 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -5,6 +5,7 @@ from .. import entities from .. import filter as filter_model +@filter_model.filter_class("content-ignore") class ContentIgnore(filter_model.ContentFilter): """根据内容忽略消息""" diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 2962ae28..2095845d 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -45,11 +45,14 @@ class LongTextProcessStage(stage.PipelineStage): self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" - - if config['strategy'] == 'image': - self.strategy_impl = image.Text2ImageStrategy(self.ap) - elif config['strategy'] == 'forward': - self.strategy_impl = forward.ForwardComponentStrategy(self.ap) + + for strategy_cls in strategy.preregistered_strategies: + if strategy_cls.name == config['strategy']: + self.strategy_impl = strategy_cls(self.ap) + break + else: + raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略") + await self.strategy_impl.initialize() async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index cfab49d9..4a790313 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -36,6 +36,7 @@ class Forward(MessageComponent): return '[聊天记录]' +@strategy_model.strategy_class("forward") class ForwardComponentStrategy(strategy_model.LongTextStrategy): async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index af34f4e6..f96f03c5 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -15,6 +15,7 @@ from .. import strategy as strategy_model from ....core import entities as core_entities +@strategy_model.strategy_class("image") class Text2ImageStrategy(strategy_model.LongTextStrategy): text_render_font: ImageFont.FreeTypeFont diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index a1f8a94f..296c5b4c 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -9,7 +9,30 @@ from ...core import app from ...core import entities as core_entities +preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] + + +def strategy_class( + name: str +) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: + def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]: + assert issubclass(cls, LongTextStrategy) + + cls.name = name + + preregistered_strategies.append(cls) + + return cls + + return decorator + + class LongTextStrategy(metaclass=abc.ABCMeta): + """长文本处理策略抽象类 + """ + + name: str + ap: app.Application def __init__(self, ap: app.Application): diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index c0eb92d6..cedc030f 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -51,28 +51,6 @@ class PreProcessor(stage.PipelineStage): query.prompt.messages = event_ctx.event.default_prompt query.messages = event_ctx.event.prompt - # 根据模型max_tokens剪裁 - max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens']) - - test_messages = query.prompt.messages + query.messages + [query.user_message] - - while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens: - # 前文都pop完了,还是大于max_tokens,由于prompt和user_messages不能删减,报错 - if len(query.prompt.messages) == 0: - return entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query, - user_notice='输入内容过长,请减少情景预设或者输入内容长度', - console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项(但不能超过所用模型最大tokens数)' - ) - - query.messages.pop(0) # pop第一个肯定是role=user的 - # 继续pop到第二个role=user前一个 - while len(query.messages) > 0 and query.messages[0].role != 'user': - query.messages.pop(0) - - test_messages = query.prompt.messages + query.messages + [query.user_message] - return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index b3e8fa18..33dedb04 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -21,8 +21,6 @@ class ChatMessageHandler(handler.MessageHandler): ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理 """ - # 取session - # 取conversation # 调API # 生成器 diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index b6d9ba7b..448ae384 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -1,11 +1,26 @@ from __future__ import annotations import abc +import typing from ...core import app +preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] + +def algo_class(name: str): + + def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]: + cls.name = name + preregistered_algos.append(cls) + return cls + + return decorator + + class ReteLimitAlgo(metaclass=abc.ABCMeta): + name: str = None + ap: app.Application def __init__(self, ap: app.Application): diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index bb69b0dd..aa380291 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -19,6 +19,7 @@ class SessionContainer: self.records = {} +@algo.algo_class("fixwin") class FixedWindowAlgo(algo.ReteLimitAlgo): containers_lock: asyncio.Lock diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index a9e29799..f43c8b06 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -16,7 +16,19 @@ class RateLimit(stage.PipelineStage): algo: algo.ReteLimitAlgo async def initialize(self): - self.algo = fixedwin.FixedWindowAlgo(self.ap) + + algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo'] + + algo_class = None + + for algo_cls in algo.preregistered_algos: + if algo_cls.name == algo_name: + algo_class = algo_cls + break + else: + raise ValueError(f'未知的限速算法: {algo_name}') + + self.algo = algo_class(self.ap) await self.algo.initialize() async def process( diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 8f418729..d795d056 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -21,15 +21,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): async def initialize(self): """初始化检查器 """ - self.rule_matchers = [ - atbot.AtBotRule(self.ap), - prefix.PrefixRule(self.ap), - regexp.RegExpRule(self.ap), - random.RandomRespRule(self.ap), - ] - for rule_matcher in self.rule_matchers: - await rule_matcher.initialize() + self.rule_matchers = [] + + for rule_matcher in rule.preregisetered_rules: + rule_inst = rule_matcher(self.ap) + await rule_inst.initialize() + self.rule_matchers.append(rule_inst) async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py index cde9ec3d..bfab4152 100644 --- a/pkg/pipeline/resprule/rule.py +++ b/pkg/pipeline/resprule/rule.py @@ -1,5 +1,6 @@ from __future__ import annotations import abc +import typing import mirai @@ -7,9 +8,20 @@ from ...core import app, entities as core_entities from . import entities +preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] + +def rule_class(name: str): + def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]: + cls.name = name + preregisetered_rules.append(cls) + return cls + return decorator + + class GroupRespondRule(metaclass=abc.ABCMeta): """群组响应规则的抽象类 """ + name: str ap: app.Application diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py index 692bee72..293cfd96 100644 --- a/pkg/pipeline/resprule/rules/atbot.py +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -7,6 +7,7 @@ from .. import entities from ....core import entities as core_entities +@rule_model.rule_class("at-bot") class AtBotRule(rule_model.GroupRespondRule): async def match( diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py index 1b61c138..99dcd4f9 100644 --- a/pkg/pipeline/resprule/rules/prefix.py +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -5,6 +5,7 @@ from .. import entities from ....core import entities as core_entities +@rule_model.rule_class("prefix") class PrefixRule(rule_model.GroupRespondRule): async def match( diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index 185e03ec..80acf6a5 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -7,6 +7,7 @@ from .. import entities from ....core import entities as core_entities +@rule_model.rule_class("random") class RandomRespRule(rule_model.GroupRespondRule): async def match( diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index 4e39d432..aaa46449 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -7,6 +7,7 @@ from .. import entities from ....core import entities as core_entities +@rule_model.rule_class("regexp") class RegExpRule(rule_model.GroupRespondRule): async def match( diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 38c31fe2..5ce1db18 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -22,6 +22,8 @@ def adapter_class( class MessageSourceAdapter(metaclass=abc.ABCMeta): + """消息平台适配器基类""" + name: str bot_account_id: int @@ -40,7 +42,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta): target_id: str, message: mirai.MessageChain ): - """发送消息 + """主动发送消息 Args: target_type (str): 目标类型,`person`或`group` diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 3d73c198..7b40f2ab 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -163,25 +163,6 @@ class PlatformManager: quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False ) - # 通知系统管理员 - # TODO delete - # async def notify_admin(self, message: str): - # await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))])) - - # async def notify_admin_message_chain(self, message: mirai.MessageChain): - # if self.ap.system_cfg.data['admin-sessions'] != []: - - # admin_list = [] - # for admin in self.ap.system_cfg.data['admin-sessions']: - # admin_list.append(admin) - - # for adm in admin_list: - # self.adapter.send_message( - # adm.split("_")[0], - # adm.split("_")[1], - # message - # ) - async def run(self): try: tasks = [] diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 0a419a06..0b3b8c09 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): msg_list = message_chain.__root__ elif type(message_chain) is list: msg_list = message_chain + elif type(message_chain) is str: + msg_list = [mirai.Plain(message_chain)] else: raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index 6d74d0ea..313249a0 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -89,6 +89,8 @@ class OfficialMessageConverter(adapter_model.MessageConverter): msg_list = message_chain.__root__ elif type(message_chain) is list: msg_list = message_chain + elif type(message_chain) is str: + msg_list = [mirai.Plain(text=message_chain)] else: raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) diff --git a/pkg/provider/requester/__init__.py b/pkg/provider/modelmgr/__init__.py similarity index 100% rename from pkg/provider/requester/__init__.py rename to pkg/provider/modelmgr/__init__.py diff --git a/pkg/provider/requester/api.py b/pkg/provider/modelmgr/api.py similarity index 65% rename from pkg/provider/requester/api.py rename to pkg/provider/modelmgr/api.py index 88ba78cd..da362468 100644 --- a/pkg/provider/requester/api.py +++ b/pkg/provider/modelmgr/api.py @@ -7,9 +7,23 @@ from ...core import app from ...core import entities as core_entities from .. import entities as llm_entities + +preregistered_requesters: list[typing.Type[LLMAPIRequester]] = [] + +def requester_class(name: str): + + def decorator(cls: typing.Type[LLMAPIRequester]) -> typing.Type[LLMAPIRequester]: + cls.name = name + preregistered_requesters.append(cls) + return cls + + return decorator + + class LLMAPIRequester(metaclass=abc.ABCMeta): """LLM API请求器 """ + name: str = None ap: app.Application diff --git a/pkg/provider/requester/apis/__init__.py b/pkg/provider/modelmgr/apis/__init__.py similarity index 100% rename from pkg/provider/requester/apis/__init__.py rename to pkg/provider/modelmgr/apis/__init__.py diff --git a/pkg/provider/requester/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py similarity index 94% rename from pkg/provider/requester/apis/chatcmpl.py rename to pkg/provider/modelmgr/apis/chatcmpl.py index 2d520017..4965acf7 100644 --- a/pkg/provider/requester/apis/chatcmpl.py +++ b/pkg/provider/modelmgr/apis/chatcmpl.py @@ -17,6 +17,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities +@api.requester_class("openai-chat-completion") class OpenAIChatCompletion(api.LLMAPIRequester): """OpenAI ChatCompletion API 请求器""" @@ -133,7 +134,10 @@ class OpenAIChatCompletion(api.LLMAPIRequester): except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: - raise errors.RequesterError(f'请求错误: {e.message}') + if 'context_length_exceeded' in e.message: + raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') + else: + raise errors.RequesterError(f'请求参数错误: {e.message}') except openai.AuthenticationError as e: raise errors.RequesterError(f'无效的 api-key: {e.message}') except openai.NotFoundError as e: diff --git a/pkg/provider/requester/entities.py b/pkg/provider/modelmgr/entities.py similarity index 76% rename from pkg/provider/requester/entities.py rename to pkg/provider/modelmgr/entities.py index d4c51d6f..277f125a 100644 --- a/pkg/provider/requester/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -5,7 +5,7 @@ import typing import pydantic from . import api -from . import token, tokenizer +from . import token class LLMModelInfo(pydantic.BaseModel): @@ -19,11 +19,7 @@ class LLMModelInfo(pydantic.BaseModel): requester: api.LLMAPIRequester - tokenizer: 'tokenizer.LLMTokenizer' - tool_call_supported: typing.Optional[bool] = False - max_tokens: typing.Optional[int] = 2048 - class Config: arbitrary_types_allowed = True diff --git a/pkg/provider/requester/errors.py b/pkg/provider/modelmgr/errors.py similarity index 100% rename from pkg/provider/requester/errors.py rename to pkg/provider/modelmgr/errors.py diff --git a/pkg/provider/requester/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py similarity index 76% rename from pkg/provider/requester/modelmgr.py rename to pkg/provider/modelmgr/modelmgr.py index e1a48bc2..a91c3110 100644 --- a/pkg/provider/requester/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -3,9 +3,8 @@ from __future__ import annotations from . import entities from ...core import app -from .apis import chatcmpl from . import token -from .tokenizers import tiktoken +from .apis import chatcmpl class ModelManager: @@ -30,9 +29,7 @@ class ModelManager: async def initialize(self): openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap) await openai_chat_completion.initialize() - openai_token_mgr = token.TokenManager(self.ap, list(self.ap.provider_cfg.data['openai-config']['api-keys'])) - - tiktoken_tokenizer = tiktoken.Tiktoken(self.ap) + openai_token_mgr = token.TokenManager("openai", list(self.ap.provider_cfg.data['openai-config']['api-keys'])) model_list = [ entities.LLMModelInfo( @@ -40,48 +37,36 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=4096 ), entities.LLMModelInfo( name="gpt-3.5-turbo-1106", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=16385 ), entities.LLMModelInfo( name="gpt-3.5-turbo-16k", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=16385 ), entities.LLMModelInfo( name="gpt-3.5-turbo-0613", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=4096 ), entities.LLMModelInfo( name="gpt-3.5-turbo-16k-0613", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=16385 ), entities.LLMModelInfo( name="gpt-3.5-turbo-0301", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=4096 ) ] @@ -93,64 +78,48 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="gpt-4-turbo-preview", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="gpt-4-1106-preview", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="gpt-4-vision-preview", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="gpt-4", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=8192 ), entities.LLMModelInfo( name="gpt-4-0613", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=8192 ), entities.LLMModelInfo( name="gpt-4-32k", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=32768 ), entities.LLMModelInfo( name="gpt-4-32k-0613", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, - tokenizer=tiktoken_tokenizer, - max_tokens=32768 ) ] @@ -163,8 +132,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=8192 ), entities.LLMModelInfo( name="OneAPI/chatglm_pro", @@ -172,8 +139,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="OneAPI/chatglm_std", @@ -181,8 +146,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="OneAPI/chatglm_lite", @@ -190,8 +153,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=128000 ), entities.LLMModelInfo( name="OneAPI/qwen-v1", @@ -199,8 +160,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=6000 ), entities.LLMModelInfo( name="OneAPI/qwen-plus-v1", @@ -208,8 +167,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=30000 ), entities.LLMModelInfo( name="OneAPI/ERNIE-Bot", @@ -217,8 +174,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=2000 ), entities.LLMModelInfo( name="OneAPI/ERNIE-Bot-turbo", @@ -226,8 +181,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=7000 ), entities.LLMModelInfo( name="OneAPI/gemini-pro", @@ -235,8 +188,6 @@ class ModelManager: token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=False, - tokenizer=tiktoken_tokenizer, - max_tokens=30720 ), ] diff --git a/pkg/provider/requester/token.py b/pkg/provider/modelmgr/token.py similarity index 100% rename from pkg/provider/requester/token.py rename to pkg/provider/modelmgr/token.py diff --git a/pkg/provider/requester/tokenizer.py b/pkg/provider/requester/tokenizer.py deleted file mode 100644 index cdd91470..00000000 --- a/pkg/provider/requester/tokenizer.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import abc -import typing - -from ...core import app -from .. import entities as llm_entities -from . import entities - - -class LLMTokenizer(metaclass=abc.ABCMeta): - """LLM分词器抽象类""" - - ap: app.Application - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - """初始化分词器 - """ - pass - - @abc.abstractmethod - async def count_token( - self, - messages: list[llm_entities.Message], - model: entities.LLMModelInfo - ) -> int: - pass diff --git a/pkg/provider/requester/tokenizers/__init__.py b/pkg/provider/requester/tokenizers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/provider/requester/tokenizers/tiktoken.py b/pkg/provider/requester/tokenizers/tiktoken.py deleted file mode 100644 index 24d2d8b6..00000000 --- a/pkg/provider/requester/tokenizers/tiktoken.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import tiktoken - -from .. import tokenizer -from ... import entities as llm_entities -from .. import entities - - -class Tiktoken(tokenizer.LLMTokenizer): - """TikToken分词器 - """ - - async def count_token( - self, - messages: list[llm_entities.Message], - model: entities.LLMModelInfo - ) -> int: - try: - encoding = tiktoken.encoding_for_model(model.name) - except KeyError: - # print("Warning: model not found. Using cl100k_base encoding.") - encoding = tiktoken.get_encoding("cl100k_base") - - num_tokens = 0 - for message in messages: - num_tokens += len(encoding.encode(message.role)) - num_tokens += len(encoding.encode(message.content if message.content is not None else '')) - num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> - return num_tokens diff --git a/pkg/provider/sysprompt/loader.py b/pkg/provider/sysprompt/loader.py index ca9e8730..9e0a6144 100644 --- a/pkg/provider/sysprompt/loader.py +++ b/pkg/provider/sysprompt/loader.py @@ -1,13 +1,27 @@ from __future__ import annotations import abc +import typing from ...core import app from . import entities +preregistered_loaders: list[typing.Type[PromptLoader]] = [] + +def loader_class(name: str): + + def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]: + cls.name = name + preregistered_loaders.append(cls) + return cls + + return decorator + + class PromptLoader(metaclass=abc.ABCMeta): """Prompt加载器抽象类 """ + name: str ap: app.Application diff --git a/pkg/provider/sysprompt/loaders/scenario.py b/pkg/provider/sysprompt/loaders/scenario.py index a559ff73..9c19d963 100644 --- a/pkg/provider/sysprompt/loaders/scenario.py +++ b/pkg/provider/sysprompt/loaders/scenario.py @@ -8,6 +8,7 @@ from .. import entities from ....provider import entities as llm_entities +@loader.loader_class("full_scenario") class ScenarioPromptLoader(loader.PromptLoader): """加载scenario目录下的json""" diff --git a/pkg/provider/sysprompt/loaders/single.py b/pkg/provider/sysprompt/loaders/single.py index 57e06ed2..3ac9c262 100644 --- a/pkg/provider/sysprompt/loaders/single.py +++ b/pkg/provider/sysprompt/loaders/single.py @@ -6,6 +6,7 @@ from .. import entities from ....provider import entities as llm_entities +@loader.loader_class("normal") class SingleSystemPromptLoader(loader.PromptLoader): """配置文件中的单条system prompt的prompt加载器 """ diff --git a/pkg/provider/sysprompt/sysprompt.py b/pkg/provider/sysprompt/sysprompt.py index eb89e8ab..c7695f5a 100644 --- a/pkg/provider/sysprompt/sysprompt.py +++ b/pkg/provider/sysprompt/sysprompt.py @@ -20,14 +20,18 @@ class PromptManager: async def initialize(self): - loader_map = { - "normal": single.SingleSystemPromptLoader, - "full_scenario": scenario.ScenarioPromptLoader - } + mode_name = self.ap.provider_cfg.data['prompt-mode'] - loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']] + loader_class = None - self.loader_inst: loader.PromptLoader = loader_cls(self.ap) + for loader_cls in loader.preregistered_loaders: + if loader_cls.name == mode_name: + loader_class = loader_cls + break + else: + raise ValueError(f'未知的 Prompt 加载器: {mode_name}') + + self.loader_inst: loader.PromptLoader = loader_class(self.ap) await self.loader_inst.initialize() await self.loader_inst.load()