mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
Merge pull request #1245 from RockChinQ/feat/invoke-pipelines
feat: pipeline invoking
This commit is contained in:
@@ -5,10 +5,25 @@ import datetime
|
||||
import sqlalchemy
|
||||
|
||||
from ....core import app
|
||||
from ....pipeline import stagemgr
|
||||
from ....entity.persistence import pipeline as persistence_pipeline
|
||||
|
||||
|
||||
default_stage_order = [
|
||||
"GroupRespondRuleCheckStage", # 群响应规则检查
|
||||
"BanSessionCheckStage", # 封禁会话检查
|
||||
"PreContentFilterStage", # 内容过滤前置阶段
|
||||
"PreProcessor", # 预处理器
|
||||
"ConversationMessageTruncator", # 会话消息截断器
|
||||
"RequireRateLimitOccupancy", # 请求速率限制占用
|
||||
"MessageProcessor", # 处理器
|
||||
"ReleaseRateLimitOccupancy", # 释放速率限制占用
|
||||
"PostContentFilterStage", # 内容过滤后置阶段
|
||||
"ResponseWrapper", # 响应包装器
|
||||
"LongTextProcessStage", # 长文本处理
|
||||
"SendResponseBackStage", # 发送响应
|
||||
]
|
||||
|
||||
|
||||
class PipelineService:
|
||||
ap: app.Application
|
||||
|
||||
@@ -49,7 +64,7 @@ class PipelineService:
|
||||
async def create_pipeline(self, pipeline_data: dict) -> str:
|
||||
pipeline_data['uuid'] = str(uuid.uuid4())
|
||||
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
|
||||
pipeline_data['stages'] = stagemgr.stage_order.copy()
|
||||
pipeline_data['stages'] = default_stage_order.copy()
|
||||
|
||||
# TODO: 检查pipeline config是否完整
|
||||
|
||||
@@ -64,9 +79,12 @@ class PipelineService:
|
||||
return pipeline_data['uuid']
|
||||
|
||||
async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None:
|
||||
del pipeline_data['uuid']
|
||||
del pipeline_data['for_version']
|
||||
del pipeline_data['stages']
|
||||
if 'uuid' in pipeline_data:
|
||||
del pipeline_data['uuid']
|
||||
if 'for_version' in pipeline_data:
|
||||
del pipeline_data['for_version']
|
||||
if 'stages' in pipeline_data:
|
||||
del pipeline_data['stages']
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from . import entities, operator, errors
|
||||
from ..config import manager as cfg_mgr
|
||||
|
||||
# 引入所有算子以便注册
|
||||
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
|
||||
from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
|
||||
|
||||
|
||||
class CommandManager:
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="default",
|
||||
help="操作情景预设",
|
||||
usage='!default\n!default set <指定情景预设为默认>'
|
||||
)
|
||||
class DefaultOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
reply_str = "当前所有情景预设: \n\n"
|
||||
|
||||
for prompt in self.ap.prompt_mgr.get_all_prompts():
|
||||
|
||||
content = ""
|
||||
for msg in prompt.messages:
|
||||
content += f" {msg.readable_str()}\n"
|
||||
|
||||
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
|
||||
|
||||
reply_str += f"当前会话使用的是: {context.session.use_prompt_name}"
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="set",
|
||||
help="设置当前会话默认情景预设",
|
||||
parent_class=DefaultOperator
|
||||
)
|
||||
class DefaultSetOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
|
||||
else:
|
||||
prompt_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
|
||||
if prompt is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
|
||||
else:
|
||||
context.session.use_prompt_name = prompt.name
|
||||
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
|
||||
@@ -11,16 +11,14 @@ import os
|
||||
from ..platform import manager as im_mgr
|
||||
from ..provider.session import sessionmgr as llm_session_mgr
|
||||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..provider import runnermgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..config import settings as settings_mgr
|
||||
from ..audit.center import v2 as center_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import manager as plugin_mgr
|
||||
from ..pipeline import pool
|
||||
from ..pipeline import controller, stagemgr, pipelinemgr
|
||||
from ..pipeline import controller, pipelinemgr
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
|
||||
from ..persistence import mgr as persistencemgr
|
||||
from ..api.http.controller import main as http_controller
|
||||
@@ -53,12 +51,9 @@ class Application:
|
||||
|
||||
model_mgr: llm_model_mgr.ModelManager = None
|
||||
|
||||
prompt_mgr: llm_prompt_mgr.PromptManager = None
|
||||
|
||||
# TODO 移动到 pipeline 里
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
|
||||
runner_mgr: runnermgr.RunnerManager = None
|
||||
|
||||
settings_mgr: settings_mgr.SettingsManager = None
|
||||
|
||||
# ======= 配置管理器 =======
|
||||
@@ -100,8 +95,6 @@ class Application:
|
||||
|
||||
ctrl: controller.Controller = None
|
||||
|
||||
stage_mgr: stagemgr.StageManager = None
|
||||
|
||||
pipeline_mgr: pipelinemgr.PipelineManager = None
|
||||
|
||||
ver_mgr: version_mgr.VersionManager = None
|
||||
@@ -232,16 +225,8 @@ class Application:
|
||||
await llm_session_mgr_inst.initialize()
|
||||
self.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(self)
|
||||
await llm_prompt_mgr_inst.initialize()
|
||||
self.prompt_mgr = llm_prompt_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
self.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
runner_mgr_inst = runnermgr.RunnerManager(self)
|
||||
await runner_mgr_inst.initialize()
|
||||
self.runner_mgr = runner_mgr_inst
|
||||
case _:
|
||||
pass
|
||||
@@ -8,8 +8,7 @@ import asyncio
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ..provider import entities as llm_entities
|
||||
from ..provider.modelmgr import entities
|
||||
from ..provider.sysprompt import entities as sysprompt_entities
|
||||
from ..provider.modelmgr import entities, modelmgr, requester
|
||||
from ..provider.tools import entities as tools_entities
|
||||
from ..platform import adapter as msadapter
|
||||
from ..platform.types import message as platform_message
|
||||
@@ -57,6 +56,15 @@ class Query(pydantic.BaseModel):
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息链,platform收到的原始消息链"""
|
||||
|
||||
bot_uuid: typing.Optional[str] = None
|
||||
"""机器人UUID。"""
|
||||
|
||||
pipeline_uuid: typing.Optional[str] = None
|
||||
"""流水线UUID。"""
|
||||
|
||||
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
|
||||
"""流水线配置,由 Pipeline 在运行开始时设置。"""
|
||||
|
||||
adapter: msadapter.MessagePlatformAdapter
|
||||
"""消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器"""
|
||||
|
||||
@@ -66,7 +74,7 @@ class Query(pydantic.BaseModel):
|
||||
messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""历史消息列表,由前置处理器阶段设置"""
|
||||
|
||||
prompt: typing.Optional[sysprompt_entities.Prompt] = None
|
||||
prompt: typing.Optional[llm_entities.Prompt] = None
|
||||
"""情景预设内容,由前置处理器阶段设置"""
|
||||
|
||||
user_message: typing.Optional[llm_entities.Message] = None
|
||||
@@ -75,8 +83,8 @@ class Query(pydantic.BaseModel):
|
||||
variables: typing.Optional[dict[str, typing.Any]] = None
|
||||
"""变量,由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
|
||||
|
||||
use_model: typing.Optional[entities.LLMModelInfo] = None
|
||||
"""使用的模型,由前置处理器阶段设置"""
|
||||
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
|
||||
"""使用的对话模型,由前置处理器阶段设置"""
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
||||
"""使用的函数,由前置处理器阶段设置"""
|
||||
@@ -88,7 +96,7 @@ class Query(pydantic.BaseModel):
|
||||
"""回复消息链,从resp_messages包装而得"""
|
||||
|
||||
# ======= 内部保留 =======
|
||||
current_stage: "pkg.pipeline.stagemgr.StageInstContainer" = None
|
||||
current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None
|
||||
"""当前所处阶段"""
|
||||
|
||||
class Config:
|
||||
@@ -118,7 +126,7 @@ class Query(pydantic.BaseModel):
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation"""
|
||||
|
||||
prompt: sysprompt_entities.Prompt
|
||||
prompt: llm_entities.Prompt
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
@@ -126,13 +134,16 @@ class Conversation(pydantic.BaseModel):
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
use_model: entities.LLMModelInfo
|
||||
use_llm_model: requester.RuntimeLLMModel
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
|
||||
|
||||
uuid: typing.Optional[str] = None
|
||||
"""该对话的 uuid,在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Session(pydantic.BaseModel):
|
||||
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
|
||||
|
||||
@@ -6,14 +6,12 @@ from .. import stage, app
|
||||
from ...utils import version, proxy, announce, platform
|
||||
from ...audit.center import v2 as center_v2
|
||||
from ...audit import identifier
|
||||
from ...pipeline import pool, controller, stagemgr, pipelinemgr
|
||||
from ...pipeline import pool, controller, pipelinemgr
|
||||
from ...plugin import manager as plugin_mgr
|
||||
from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...provider import runnermgr
|
||||
from ...platform import manager as im_mgr
|
||||
from ...persistence import mgr as persistencemgr
|
||||
from ...api.http.controller import main as http_controller
|
||||
@@ -61,10 +59,7 @@ class BuildAppStage(stage.BootingStage):
|
||||
},
|
||||
runtime_info={
|
||||
"admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
|
||||
"msg_source": str([
|
||||
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
|
||||
for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
|
||||
]),
|
||||
"msg_source": str([]),
|
||||
},
|
||||
)
|
||||
ap.ctr_mgr = center_v2_api
|
||||
@@ -99,26 +94,14 @@ class BuildAppStage(stage.BootingStage):
|
||||
await llm_session_mgr_inst.initialize()
|
||||
ap.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
|
||||
await llm_prompt_mgr_inst.initialize()
|
||||
ap.prompt_mgr = llm_prompt_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
runner_mgr_inst = runnermgr.RunnerManager(ap)
|
||||
await runner_mgr_inst.initialize()
|
||||
ap.runner_mgr = runner_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.platform_mgr = im_mgr_inst
|
||||
|
||||
stage_mgr = stagemgr.StageManager(ap)
|
||||
await stage_mgr.initialize()
|
||||
ap.stage_mgr = stage_mgr
|
||||
|
||||
pipeline_mgr = pipelinemgr.PipelineManager(ap)
|
||||
await pipeline_mgr.initialize()
|
||||
ap.pipeline_mgr = pipeline_mgr
|
||||
|
||||
@@ -233,3 +233,10 @@ class AsyncTaskManager:
|
||||
if not wrapper.task.done() and scope in wrapper.scopes:
|
||||
|
||||
wrapper.task.cancel()
|
||||
|
||||
def cancel_task(self, task_id: int):
|
||||
for wrapper in self.tasks:
|
||||
if wrapper.id == task_id:
|
||||
if not wrapper.task.done():
|
||||
wrapper.task.cancel()
|
||||
return
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
@@ -13,7 +13,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
仅检查query中群号或个人号是否在访问控制列表中。
|
||||
"""
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
@@ -24,9 +24,9 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
|
||||
found = False
|
||||
|
||||
mode = self.ap.pipeline_cfg.data['access-control']['mode']
|
||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||
|
||||
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
|
||||
sess_list = query.pipeline_config['trigger']['access-control'][mode]
|
||||
|
||||
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
|
||||
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
@@ -35,17 +35,18 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
self.filter_chain = []
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
|
||||
filters_required = [
|
||||
"content-ignore",
|
||||
]
|
||||
|
||||
if self.ap.pipeline_cfg.data['check-sensitive-words']:
|
||||
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
|
||||
filters_required.append("ban-word-filter")
|
||||
|
||||
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
filters_required.append("baidu-cloud-examine")
|
||||
# TODO revert it
|
||||
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
# filters_required.append("baidu-cloud-examine")
|
||||
|
||||
for filter in filter_model.preregistered_filters:
|
||||
if filter.name in filters_required:
|
||||
@@ -65,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
只要有一个不通过就不放行,只放行 PASS 的消息
|
||||
"""
|
||||
|
||||
if not self.ap.pipeline_cfg.data['income-msg-check']:
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
@@ -73,7 +74,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
result = await filter.process(query, message)
|
||||
|
||||
if result.level in [
|
||||
filter_entities.ResultLevel.BLOCK,
|
||||
@@ -105,7 +106,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""请求llm后处理响应
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
"""
|
||||
if message is None:
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
@@ -114,7 +115,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
message = message.strip()
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.POST in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
result = await filter.process(query, message)
|
||||
|
||||
if result.level == filter_entities.ResultLevel.BLOCK:
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
@@ -64,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
|
||||
@@ -4,6 +4,7 @@ import aiohttp
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
|
||||
@@ -26,7 +27,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from .. import filter as filter_model
|
||||
from .. import entities
|
||||
from ....config import manager as cfg_mgr
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@filter_model.filter_class("ban-word-filter")
|
||||
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
|
||||
@@ -3,6 +3,7 @@ import re
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@filter_model.filter_class("content-ignore")
|
||||
@@ -15,9 +16,9 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']:
|
||||
for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
@@ -26,8 +27,8 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
|
||||
)
|
||||
|
||||
if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']:
|
||||
for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']:
|
||||
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
|
||||
if re.search(rule, message):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
|
||||
@@ -50,14 +50,23 @@ class Controller:
|
||||
continue
|
||||
|
||||
if selected_query:
|
||||
async def _process_query(selected_query):
|
||||
|
||||
async def _process_query(selected_query: entities.Query):
|
||||
async with self.semaphore: # 总并发上限
|
||||
await self.process_query(selected_query)
|
||||
# find pipeline
|
||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
|
||||
if bot:
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
self.ap.task_mgr.create_task(
|
||||
_process_query(selected_query),
|
||||
kind="query",
|
||||
@@ -70,127 +79,6 @@ class Controller:
|
||||
self.ap.logger.error(f"控制器循环出错: {e}")
|
||||
self.ap.logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
platform_message.Plain(result.user_notice)
|
||||
)
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
*result.user_notice
|
||||
)
|
||||
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice,
|
||||
query.adapter
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
if result.console_notice:
|
||||
self.ap.logger.info(result.console_notice)
|
||||
if result.error_notice:
|
||||
self.ap.logger.error(result.error_notice)
|
||||
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
):
|
||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||
|
||||
如何看懂这里为什么这么写?
|
||||
去问 GPT-4:
|
||||
Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None],
|
||||
如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result,
|
||||
调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
|
||||
Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
现在假设C返回的是AsyncGenerator,那么执行顺序是:
|
||||
|
||||
A B C D E F G C D E F G C D E F G ...
|
||||
Q3: 但是如果不止一个stage会返回生成器呢?
|
||||
"""
|
||||
i = stage_index
|
||||
|
||||
while i < len(self.ap.stage_mgr.stage_containers):
|
||||
stage_container = self.ap.stage_mgr.stage_containers[i]
|
||||
|
||||
query.current_stage = stage_container # 标记到 Query 对象里
|
||||
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
if isinstance(result, typing.Coroutine):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
try:
|
||||
|
||||
# ======== 触发 MessageReceived 事件 ========
|
||||
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_type(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
async def run(self):
|
||||
"""运行控制器
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
|
||||
from ...core import app
|
||||
from . import strategy
|
||||
from .strategies import image, forward
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...platform.types import message as platform_message
|
||||
@@ -23,8 +23,8 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
strategy_impl: strategy.LongTextStrategy
|
||||
|
||||
async def initialize(self):
|
||||
config = self.ap.platform_cfg.data['long-text-process']
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
config = pipeline_config['output']['long-text-processing']
|
||||
if config['strategy'] == 'image':
|
||||
use_font = config['font-path']
|
||||
try:
|
||||
@@ -42,12 +42,12 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
else:
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
|
||||
|
||||
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
|
||||
except:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
|
||||
|
||||
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
|
||||
|
||||
for strategy_cls in strategy.preregistered_strategies:
|
||||
if strategy_cls.name == config['strategy']:
|
||||
@@ -69,7 +69,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
if contains_non_plain:
|
||||
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
|
||||
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
|
||||
elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']:
|
||||
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
|
||||
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -8,6 +8,7 @@ import re
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import functools
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
@@ -17,15 +18,18 @@ from ....core import entities as core_entities
|
||||
@strategy_model.strategy_class("image")
|
||||
class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
text_render_font: ImageFont.FreeTypeFont
|
||||
|
||||
async def initialize(self):
|
||||
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
|
||||
pass
|
||||
|
||||
@functools.lru_cache(maxsize=16)
|
||||
def get_font(self, query: core_entities.Query):
|
||||
return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8")
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time()))
|
||||
save_as='temp/{}.png'.format(int(time.time())),
|
||||
query=query
|
||||
)
|
||||
|
||||
compressed_path, size = self.compress_image(
|
||||
@@ -127,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
return outfile, self.get_size(outfile)
|
||||
|
||||
|
||||
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
|
||||
def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None):
|
||||
|
||||
text_str = text_str.replace("\t", " ")
|
||||
|
||||
@@ -142,7 +146,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
|
||||
for line in lines:
|
||||
# 如果长了就分割
|
||||
line_width = self.text_render_font.getlength(line)
|
||||
line_width = self.get_font(query).getlength(line)
|
||||
self.ap.logger.debug("line_width: {}".format(line_width))
|
||||
if line_width < text_width:
|
||||
final_lines.append(line)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from . import truncator
|
||||
from .truncators import round
|
||||
@@ -14,8 +14,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
||||
"""
|
||||
trun: truncator.Truncator
|
||||
|
||||
async def initialize(self):
|
||||
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
use_method = "round"
|
||||
|
||||
for trun in truncator.preregistered_truncators:
|
||||
if trun.name == use_method:
|
||||
|
||||
@@ -12,7 +12,7 @@ class RoundTruncator(truncator.Truncator):
|
||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
||||
"""截断
|
||||
"""
|
||||
max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']
|
||||
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
||||
|
||||
temp_messages = []
|
||||
|
||||
|
||||
@@ -1,12 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ..core import app, entities
|
||||
from . import entities as pipeline_entities
|
||||
from ..entity.persistence import pipeline as persistence_pipeline
|
||||
from . import stagemgr, stage
|
||||
from . import stage
|
||||
from ..platform.types import message as platform_message, events as platform_events
|
||||
from ..plugin import events
|
||||
|
||||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .process import process
|
||||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
from .wrapper import wrapper
|
||||
from .preproc import preproc
|
||||
from .ratelimit import ratelimit
|
||||
from .msgtrun import msgtrun
|
||||
|
||||
|
||||
class StageInstContainer():
|
||||
"""阶段实例容器
|
||||
"""
|
||||
|
||||
inst_name: str
|
||||
|
||||
inst: stage.PipelineStage
|
||||
|
||||
def __init__(self, inst_name: str, inst: stage.PipelineStage):
|
||||
self.inst_name = inst_name
|
||||
self.inst = inst
|
||||
|
||||
|
||||
class RuntimePipeline:
|
||||
@@ -17,16 +45,146 @@ class RuntimePipeline:
|
||||
pipeline_entity: persistence_pipeline.LegacyPipeline
|
||||
"""流水线实体"""
|
||||
|
||||
stage_containers: list[stagemgr.StageInstContainer]
|
||||
stage_containers: list[StageInstContainer]
|
||||
"""阶段实例容器"""
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]):
|
||||
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]):
|
||||
self.ap = ap
|
||||
self.pipeline_entity = pipeline_entity
|
||||
self.stage_containers = stage_containers
|
||||
|
||||
async def run(self):
|
||||
pass
|
||||
async def run(self, query: entities.Query):
|
||||
query.pipeline_config = self.pipeline_entity.config
|
||||
await self.process_query(query)
|
||||
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
platform_message.Plain(result.user_notice)
|
||||
)
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
*result.user_notice
|
||||
)
|
||||
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
|
||||
result.user_notice.insert(
|
||||
0,
|
||||
platform_message.At(
|
||||
query.message_event.sender.id
|
||||
)
|
||||
)
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin']
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
if result.console_notice:
|
||||
self.ap.logger.info(result.console_notice)
|
||||
if result.error_notice:
|
||||
self.ap.logger.error(result.error_notice)
|
||||
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
):
|
||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||
|
||||
如何看懂这里为什么这么写?
|
||||
去问 GPT-4:
|
||||
Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None],
|
||||
如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result,
|
||||
调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
|
||||
Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
现在假设C返回的是AsyncGenerator,那么执行顺序是:
|
||||
|
||||
A B C D E F G C D E F G C D E F G ...
|
||||
Q3: 但是如果不止一个stage会返回生成器呢?
|
||||
"""
|
||||
i = stage_index
|
||||
|
||||
while i < len(self.stage_containers):
|
||||
stage_container = self.stage_containers[i]
|
||||
|
||||
query.current_stage = stage_container # 标记到 Query 对象里
|
||||
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
if isinstance(result, typing.Coroutine):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
try:
|
||||
|
||||
# ======== 触发 MessageReceived 事件 ========
|
||||
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_type(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
|
||||
class PipelineManager:
|
||||
@@ -70,12 +228,15 @@ class PipelineManager:
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||
|
||||
# initialize stage containers according to pipeline_entity.stages
|
||||
stage_containers = []
|
||||
stage_containers: list[StageInstContainer] = []
|
||||
for stage_name in pipeline_entity.stages:
|
||||
stage_containers.append(stagemgr.StageInstContainer(
|
||||
stage_name=stage_name,
|
||||
stage_class=self.stage_dict[stage_name]
|
||||
stage_containers.append(StageInstContainer(
|
||||
inst_name=stage_name,
|
||||
inst=self.stage_dict[stage_name](self.ap)
|
||||
))
|
||||
|
||||
for stage_container in stage_containers:
|
||||
await stage_container.inst.initialize(pipeline_entity.config)
|
||||
|
||||
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
|
||||
self.pipelines.append(runtime_pipeline)
|
||||
|
||||
@@ -28,15 +28,17 @@ class QueryPool:
|
||||
|
||||
async def add_query(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
launcher_type: entities.LauncherTypes,
|
||||
launcher_id: typing.Union[int, str],
|
||||
sender_id: typing.Union[int, str],
|
||||
message_event: platform_events.MessageEvent,
|
||||
message_chain: platform_message.MessageChain,
|
||||
adapter: msadapter.MessagePlatformAdapter
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
) -> entities.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
bot_uuid=bot_uuid,
|
||||
query_id=self.query_id_counter,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...provider import entities as llm_entities
|
||||
from ...plugin import events
|
||||
@@ -33,16 +33,16 @@ class PreProcessor(stage.PipelineStage):
|
||||
"""
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||
conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config['ai']['local-agent']['prompt'])
|
||||
|
||||
# 设置query
|
||||
query.session = session
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
query.use_model = conversation.use_model
|
||||
query.use_llm_model = conversation.use_llm_model
|
||||
|
||||
query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
|
||||
query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
|
||||
|
||||
query.variables = {
|
||||
"session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
@@ -50,8 +50,9 @@ class PreProcessor(stage.PipelineStage):
|
||||
"msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()),
|
||||
}
|
||||
|
||||
# 检查vision是否启用,没启用就删除所有图片
|
||||
if not self.ap.provider_cfg.data['enable-vision'] or (self.ap.provider_cfg.data['runner'] == 'local-agent' and not query.use_model.vision_supported):
|
||||
# Check if this model supports vision, if not, remove all images
|
||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||
if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
@@ -69,7 +70,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
)
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported):
|
||||
if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if me.base64 is not None:
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_image_base64(me.base64)
|
||||
|
||||
@@ -9,7 +9,9 @@ import json
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities, runnermgr
|
||||
from ....provider import entities as llm_entities
|
||||
from ....provider import runner as runner_module
|
||||
from ....provider.runners import localagent, difysvapi, dashscopeapi
|
||||
from ....plugin import events
|
||||
|
||||
from ....platform.types import message as platform_message
|
||||
@@ -56,12 +58,6 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
)
|
||||
else:
|
||||
|
||||
if not self.ap.provider_cfg.data['enable-chat']:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
if event_ctx.event.alter is not None:
|
||||
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
|
||||
query.user_message.content = event_ctx.event.alter
|
||||
@@ -72,7 +68,12 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
try:
|
||||
|
||||
runner = self.ap.runner_mgr.get_runner()
|
||||
for r in runner_module.preregistered_runners:
|
||||
if r.name == query.pipeline_config["ai"]["runner"]["runner"]:
|
||||
runner = r(self.ap, query.pipeline_config)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}")
|
||||
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
@@ -93,10 +94,12 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
|
||||
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}',
|
||||
user_notice='请求失败' if hide_exception_info else f'{e}',
|
||||
error_notice=f'{e}',
|
||||
debug_notice=traceback.format_exc()
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from ...core import app, entities as core_entities
|
||||
from . import handler
|
||||
from .handlers import chat, command
|
||||
from .. import entities
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
@@ -23,7 +23,7 @@ class Processor(stage.PipelineStage):
|
||||
|
||||
chat_handler: handler.MessageHandler
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
self.cmd_handler = command.CommandHandler(self.ap)
|
||||
self.chat_handler = chat.ChatMessageHandler(self.ap)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import app, entities as core_entities
|
||||
|
||||
|
||||
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||
@@ -31,7 +31,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
"""进入处理流程
|
||||
|
||||
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
|
||||
@@ -46,7 +46,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
"""退出处理流程
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
import time
|
||||
import typing
|
||||
from .. import algo
|
||||
from ....core import entities as core_entities
|
||||
|
||||
# 固定窗口算法
|
||||
class SessionContainer:
|
||||
@@ -30,7 +31,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
self.containers_lock = asyncio.Lock()
|
||||
self.containers = {}
|
||||
|
||||
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
# 加锁,找容器
|
||||
container: SessionContainer = None
|
||||
|
||||
@@ -47,12 +48,13 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
async with container.wait_lock:
|
||||
|
||||
# 获取窗口大小和限制
|
||||
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size']
|
||||
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit']
|
||||
window_size = query.pipeline_config['safety']['rate-limit']['window-length']
|
||||
limitation = query.pipeline_config['safety']['rate-limit']['limitation']
|
||||
|
||||
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
|
||||
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
|
||||
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
|
||||
# TODO revert it
|
||||
# if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
|
||||
# window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
|
||||
# limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
|
||||
|
||||
# 获取当前时间戳
|
||||
now = int(time.time())
|
||||
@@ -65,9 +67,9 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
# 如果访问次数超过了限制
|
||||
if count >= limitation:
|
||||
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
|
||||
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
|
||||
return False
|
||||
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
|
||||
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
|
||||
# 等待下一窗口
|
||||
await asyncio.sleep(window_size - time.time() % window_size)
|
||||
|
||||
@@ -84,5 +86,5 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
# 返回True
|
||||
return True
|
||||
|
||||
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
pass
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import entities, stagemgr, stage
|
||||
from .. import entities, stage
|
||||
from . import algo
|
||||
from .algos import fixedwin
|
||||
from ...core import entities as core_entities
|
||||
@@ -18,9 +18,9 @@ class RateLimit(stage.PipelineStage):
|
||||
|
||||
algo: algo.ReteLimitAlgo
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
|
||||
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
|
||||
algo_name = 'fixwin'
|
||||
|
||||
algo_class = None
|
||||
|
||||
@@ -46,6 +46,7 @@ class RateLimit(stage.PipelineStage):
|
||||
"""
|
||||
if stage_inst_name == "RequireRateLimitOccupancy":
|
||||
if await self.algo.require_access(
|
||||
query,
|
||||
query.launcher_type.value,
|
||||
query.launcher_id,
|
||||
):
|
||||
@@ -62,6 +63,7 @@ class RateLimit(stage.PipelineStage):
|
||||
)
|
||||
elif stage_inst_name == "ReleaseRateLimitOccupancy":
|
||||
await self.algo.release_access(
|
||||
query,
|
||||
query.launcher_type.value,
|
||||
query.launcher_id,
|
||||
)
|
||||
|
||||
@@ -5,8 +5,10 @@ import asyncio
|
||||
|
||||
|
||||
from ...core import app
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
@@ -19,8 +21,8 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
|
||||
random_range = (self.ap.platform_cfg.data['force-delay']['min'], self.ap.platform_cfg.data['force-delay']['max'])
|
||||
|
||||
random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max'])
|
||||
|
||||
random_delay = random.uniform(*random_range)
|
||||
|
||||
@@ -31,10 +33,20 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
|
||||
await asyncio.sleep(random_delay)
|
||||
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
query.resp_message_chain[-1],
|
||||
adapter=query.adapter
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
|
||||
query.resp_message_chain[-1].insert(
|
||||
0,
|
||||
platform_message.At(
|
||||
query.message_event.sender.id
|
||||
)
|
||||
)
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=query.resp_message_chain[-1],
|
||||
quote_origin=quote_origin
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -5,7 +5,7 @@ from ...core import app
|
||||
from . import entities as rule_entities, rule
|
||||
from .rules import atbot, prefix, regexp, random
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
@@ -20,7 +20,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
rule_matchers: list[rule.GroupRespondRule]
|
||||
"""检查器实例"""
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
"""初始化检查器
|
||||
"""
|
||||
|
||||
@@ -39,12 +39,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
new_query=query
|
||||
)
|
||||
|
||||
rules = self.ap.pipeline_cfg.data['respond-rules']
|
||||
rules = query.pipeline_config['trigger']['group-respond-rules']
|
||||
|
||||
use_rule = rules['default']
|
||||
use_rule = rules
|
||||
|
||||
if str(query.launcher_id) in rules:
|
||||
use_rule = rules[str(query.launcher_id)]
|
||||
# TODO revert it
|
||||
# if str(query.launcher_id) in rules:
|
||||
# use_rule = rules[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
|
||||
|
||||
@@ -28,7 +28,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
"""初始化
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..core import app
|
||||
from . import stage
|
||||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .process import process
|
||||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
from .wrapper import wrapper
|
||||
from .preproc import preproc
|
||||
from .ratelimit import ratelimit
|
||||
from .msgtrun import msgtrun
|
||||
|
||||
|
||||
# 请求处理阶段顺序
|
||||
stage_order = [
|
||||
"GroupRespondRuleCheckStage", # 群响应规则检查
|
||||
"BanSessionCheckStage", # 封禁会话检查
|
||||
"PreContentFilterStage", # 内容过滤前置阶段
|
||||
"PreProcessor", # 预处理器
|
||||
"ConversationMessageTruncator", # 会话消息截断器
|
||||
"RequireRateLimitOccupancy", # 请求速率限制占用
|
||||
"MessageProcessor", # 处理器
|
||||
"ReleaseRateLimitOccupancy", # 释放速率限制占用
|
||||
"PostContentFilterStage", # 内容过滤后置阶段
|
||||
"ResponseWrapper", # 响应包装器
|
||||
"LongTextProcessStage", # 长文本处理
|
||||
"SendResponseBackStage", # 发送响应
|
||||
]
|
||||
|
||||
|
||||
class StageInstContainer():
|
||||
"""阶段实例容器
|
||||
"""
|
||||
|
||||
inst_name: str
|
||||
|
||||
inst: stage.PipelineStage
|
||||
|
||||
def __init__(self, inst_name: str, inst: stage.PipelineStage):
|
||||
self.inst_name = inst_name
|
||||
self.inst = inst
|
||||
|
||||
|
||||
class StageManager:
|
||||
ap: app.Application
|
||||
|
||||
stage_containers: list[StageInstContainer]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.stage_containers = []
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化
|
||||
"""
|
||||
|
||||
for name, cls in stage.preregistered_stages.items():
|
||||
self.stage_containers.append(StageInstContainer(
|
||||
inst_name=name,
|
||||
inst=cls(self.ap)
|
||||
))
|
||||
|
||||
for stage_containers in self.stage_containers:
|
||||
await stage_containers.inst.initialize()
|
||||
|
||||
# 按照 stage_order 排序
|
||||
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))
|
||||
@@ -5,7 +5,7 @@ import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...plugin import events
|
||||
@@ -22,7 +22,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
- resp_message_chain
|
||||
"""
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
@@ -110,7 +110,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
|
||||
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
|
||||
|
||||
if self.ap.platform_cfg.data['track-function-calls']:
|
||||
if query.pipeline_config['output']['misc']['track-function-calls']:
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
|
||||
@@ -50,6 +50,41 @@ class RuntimeBot:
|
||||
self.adapter = adapter
|
||||
self.task_context = taskmgr.TaskContext()
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
bot_uuid=self.bot_entity.uuid,
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
platform_events.FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
self.adapter.register_listener(
|
||||
platform_events.GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
||||
async def exception_wrapper():
|
||||
@@ -78,14 +113,16 @@ class RuntimeBot:
|
||||
async def shutdown(self):
|
||||
await self.adapter.kill()
|
||||
|
||||
self.ap.task_mgr.cancel_task(self.task_wrapper.id)
|
||||
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
class PlatformManager:
|
||||
|
||||
# adapter: msadapter.MessageSourceAdapter = None
|
||||
adapters: list[msadapter.MessagePlatformAdapter] = []
|
||||
adapters: list[msadapter.MessagePlatformAdapter] = [] # deprecated
|
||||
|
||||
message_platform_adapter_components: list[engine.Component] = []
|
||||
message_platform_adapter_components: list[engine.Component] = [] # deprecated
|
||||
|
||||
# ====== 4.0 ======
|
||||
ap: app.Application = None
|
||||
@@ -135,49 +172,20 @@ class PlatformManager:
|
||||
bot_entity = persistence_bot.Bot(**bot_entity._mapping)
|
||||
elif isinstance(bot_entity, dict):
|
||||
bot_entity = persistence_bot.Bot(**bot_entity)
|
||||
|
||||
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
||||
bot_entity.adapter_config,
|
||||
self.ap
|
||||
)
|
||||
|
||||
adapter_inst.register_listener(
|
||||
platform_events.FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
adapter_inst.register_listener(
|
||||
platform_events.GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
runtime_bot = RuntimeBot(
|
||||
ap=self.ap,
|
||||
bot_entity=bot_entity,
|
||||
adapter=adapter_inst
|
||||
)
|
||||
|
||||
await runtime_bot.initialize()
|
||||
|
||||
self.bots.append(runtime_bot)
|
||||
|
||||
return runtime_bot
|
||||
@@ -209,50 +217,36 @@ class PlatformManager:
|
||||
return None
|
||||
|
||||
async def write_back_config(self, adapter_name: str, adapter_inst: msadapter.MessagePlatformAdapter, config: dict):
|
||||
index = -2
|
||||
# index = -2
|
||||
|
||||
for i, adapter in enumerate(self.adapters):
|
||||
if adapter == adapter_inst:
|
||||
index = i
|
||||
break
|
||||
# for i, adapter in enumerate(self.adapters):
|
||||
# if adapter == adapter_inst:
|
||||
# index = i
|
||||
# break
|
||||
|
||||
if index == -2:
|
||||
raise Exception('平台适配器未找到')
|
||||
# if index == -2:
|
||||
# raise Exception('平台适配器未找到')
|
||||
|
||||
# 只修改启用的适配器
|
||||
real_index = -1
|
||||
# # 只修改启用的适配器
|
||||
# real_index = -1
|
||||
|
||||
for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
|
||||
if adapter['enable']:
|
||||
index -= 1
|
||||
if index == -1:
|
||||
real_index = i
|
||||
break
|
||||
# for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
|
||||
# if adapter['enable']:
|
||||
# index -= 1
|
||||
# if index == -1:
|
||||
# real_index = i
|
||||
# break
|
||||
|
||||
new_cfg = {
|
||||
'adapter': adapter_name,
|
||||
'enable': True,
|
||||
**config
|
||||
}
|
||||
self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
|
||||
await self.ap.platform_cfg.dump_config()
|
||||
# new_cfg = {
|
||||
# 'adapter': adapter_name,
|
||||
# 'enable': True,
|
||||
# **config
|
||||
# }
|
||||
# self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
|
||||
# await self.ap.platform_cfg.dump_config()
|
||||
|
||||
async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage):
|
||||
|
||||
msg.insert(
|
||||
0,
|
||||
platform_message.At(
|
||||
event.sender.id
|
||||
)
|
||||
)
|
||||
|
||||
await adapter.reply_message(
|
||||
event,
|
||||
msg,
|
||||
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False
|
||||
)
|
||||
# TODO implement this
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
# This method will only be called when the application launching
|
||||
@@ -264,4 +258,4 @@ class PlatformManager:
|
||||
for bot in self.bots:
|
||||
if bot.enable:
|
||||
await bot.shutdown()
|
||||
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)
|
||||
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)
|
||||
|
||||
@@ -238,4 +238,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
await self.bot._server_app.run_task(**self.config)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
# Current issue: existing connection will not be closed
|
||||
# self.should_shutdown = True
|
||||
return False
|
||||
|
||||
@@ -28,9 +28,11 @@ spec:
|
||||
label:
|
||||
en_US: Intents
|
||||
zh_CN: 权限
|
||||
type: array[string]
|
||||
type: array
|
||||
required: true
|
||||
default: []
|
||||
items:
|
||||
type: string
|
||||
execution:
|
||||
python:
|
||||
path: ./qqbotpy.py
|
||||
|
||||
@@ -222,10 +222,10 @@ class EventContext:
|
||||
Args:
|
||||
message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
|
||||
"""
|
||||
await self.host.ap.platform_mgr.send(
|
||||
event=self.event.query.message_event,
|
||||
msg=message_chain,
|
||||
adapter=self.event.query.adapter,
|
||||
# TODO 添加 at_sender 和 quote_origin 参数
|
||||
await self.event.query.adapter.reply_message(
|
||||
message_source=self.event.query.message_event,
|
||||
message=message_chain
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
|
||||
@@ -4,6 +4,8 @@ import typing
|
||||
import enum
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from pkg.provider import entities
|
||||
|
||||
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
@@ -124,3 +126,13 @@ class Message(pydantic.BaseModel):
|
||||
mc.insert(0, platform_message.Plain(prefix_text))
|
||||
|
||||
return platform_message.MessageChain(mc)
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
"""名称"""
|
||||
|
||||
messages: list[entities.Message]
|
||||
"""消息列表"""
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import sqlalchemy
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from . import entities, requester
|
||||
from ...core import app
|
||||
@@ -16,23 +17,6 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm
|
||||
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
||||
|
||||
|
||||
class RuntimeLLMModel:
|
||||
"""运行时模型"""
|
||||
|
||||
model_entity: persistence_model.LLMModel
|
||||
"""模型数据"""
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
"""api key管理器"""
|
||||
|
||||
requester: requester.LLMAPIRequester
|
||||
"""请求器实例"""
|
||||
|
||||
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester):
|
||||
self.model_entity = model_entity
|
||||
self.token_mgr = token_mgr
|
||||
self.requester = requester
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""模型管理器"""
|
||||
@@ -47,7 +31,7 @@ class ModelManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
llm_models: list[RuntimeLLMModel]
|
||||
llm_models: list[requester.RuntimeLLMModel]
|
||||
|
||||
requester_components: list[engine.Component]
|
||||
|
||||
@@ -99,16 +83,20 @@ class ModelManager:
|
||||
elif isinstance(model_info, dict):
|
||||
model_info = persistence_model.LLMModel(**model_info)
|
||||
|
||||
runtime_llm_model = RuntimeLLMModel(
|
||||
requester_inst = self.requester_dict[model_info.requester](
|
||||
ap=self.ap,
|
||||
config=model_info.requester_config
|
||||
)
|
||||
|
||||
await requester_inst.initialize()
|
||||
|
||||
runtime_llm_model = requester.RuntimeLLMModel(
|
||||
model_entity=model_info,
|
||||
token_mgr=token.TokenManager(
|
||||
name=model_info.uuid,
|
||||
tokens=model_info.api_keys,
|
||||
),
|
||||
requester=self.requester_dict[model_info.requester](
|
||||
ap=self.ap,
|
||||
config=model_info.requester_config
|
||||
)
|
||||
requester=requester_inst
|
||||
)
|
||||
self.llm_models.append(runtime_llm_model)
|
||||
|
||||
|
||||
@@ -6,8 +6,27 @@ import typing
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from . import entities as modelmgr_entities
|
||||
from ..tools import entities as tools_entities
|
||||
from ...entity.persistence import model as persistence_model
|
||||
from . import token
|
||||
|
||||
|
||||
class RuntimeLLMModel:
|
||||
"""运行时模型"""
|
||||
|
||||
model_entity: persistence_model.LLMModel
|
||||
"""模型数据"""
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
"""api key管理器"""
|
||||
|
||||
requester: LLMAPIRequester
|
||||
"""请求器实例"""
|
||||
|
||||
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester):
|
||||
self.model_entity = model_entity
|
||||
self.token_mgr = token_mgr
|
||||
self.requester = requester
|
||||
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
@@ -31,21 +50,11 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def preprocess(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
):
|
||||
"""预处理
|
||||
|
||||
在这里处理特定API对Query对象的兼容性问题。
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def call(
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: modelmgr_entities.LLMModelInfo,
|
||||
model: RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
@@ -53,7 +62,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""调用API
|
||||
|
||||
Args:
|
||||
model (modelmgr_entities.LLMModelInfo): 使用的模型信息
|
||||
model (RuntimeLLMModel): 使用的模型信息
|
||||
messages (typing.List[llm_entities.Message]): 消息对象列表
|
||||
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
|
||||
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
|
||||
|
||||
@@ -24,16 +24,16 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
client: anthropic.AsyncAnthropic
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.anthropic.com/v1',
|
||||
'base_url': 'https://api.anthropic.com/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
|
||||
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
|
||||
base_url=self.requester_cfg['base_url'],
|
||||
# cast to a valid type because mypy doesn't understand our type narrowing
|
||||
timeout=typing.cast(httpx.Timeout, self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout']),
|
||||
timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']),
|
||||
limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS,
|
||||
follow_redirects=True,
|
||||
trust_env=True,
|
||||
@@ -44,17 +44,18 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
||||
http_client=httpx_client,
|
||||
)
|
||||
|
||||
async def call(
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: entities.LLMModelInfo,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
|
||||
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
|
||||
args["model"] = model.name if model.model_name is None else model.model_name
|
||||
args = extra_args.copy()
|
||||
args["model"] = model.model_entity.name
|
||||
|
||||
# 处理消息
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: Anthropic
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 阿里云百炼
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -26,7 +26,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
"base-url": "https://api.openai.com/v1",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
|
||||
self.client = openai.AsyncClient(
|
||||
api_key="",
|
||||
base_url=self.requester_cfg["base-url"],
|
||||
base_url=self.requester_cfg["base_url"],
|
||||
timeout=self.requester_cfg["timeout"],
|
||||
http_client=httpx.AsyncClient(
|
||||
trust_env=True, timeout=self.requester_cfg["timeout"]
|
||||
@@ -51,7 +51,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
self,
|
||||
chat_completion: chat_completion.ChatCompletion,
|
||||
) -> llm_entities.Message:
|
||||
chatcmpl_message = chat_completion.choices[0].message.dict()
|
||||
chatcmpl_message = chat_completion.choices[0].message.model_dump()
|
||||
|
||||
# 确保 role 字段存在且不为 None
|
||||
if "role" not in chatcmpl_message or chatcmpl_message["role"] is None:
|
||||
@@ -65,16 +65,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg["args"].copy()
|
||||
args["model"] = (
|
||||
use_model.name if use_model.model_name is None else use_model.model_name
|
||||
)
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
@@ -104,10 +102,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||
|
||||
return message
|
||||
|
||||
async def call(
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: entities.LLMModelInfo,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: OpenAI
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -13,7 +13,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Deepseek ChatCompletion API 请求器"""
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.deepseek.com',
|
||||
'base_url': 'https://api.deepseek.com',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -21,14 +21,14 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 深度求索
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -18,7 +18,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Gitee AI ChatCompletions API 请求器"""
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://ai.gitee.com/v1',
|
||||
'base_url': 'https://ai.gitee.com/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -26,14 +26,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: Gitee AI
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'http://127.0.0.1:1234/v1',
|
||||
'base_url': 'http://127.0.0.1:1234/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: LM Studio
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -15,7 +15,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Moonshot ChatCompletion API 请求器"""
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.moonshot.cn/v1',
|
||||
'base_url': 'https://api.moonshot.cn/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -23,14 +23,14 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 月之暗面
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -22,35 +22,38 @@ REQUESTER_NAME: str = "ollama-chat"
|
||||
|
||||
class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
"""Ollama平台 ChatCompletion API请求器"""
|
||||
|
||||
client: ollama.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'http://127.0.0.1:11434',
|
||||
'timeout': 120,
|
||||
"base_url": "http://127.0.0.1:11434",
|
||||
"timeout": 120,
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
os.environ['OLLAMA_HOST'] = self.requester_cfg['base-url']
|
||||
self.client = ollama.AsyncClient(
|
||||
timeout=self.requester_cfg['timeout']
|
||||
)
|
||||
os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"]
|
||||
self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"])
|
||||
|
||||
async def _req(self,
|
||||
args: dict,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
return await self.client.chat(
|
||||
**args
|
||||
)
|
||||
async def _req(
|
||||
self,
|
||||
args: dict,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
return await self.client.chat(**args)
|
||||
|
||||
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
|
||||
user_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message:
|
||||
args: Any = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
async def _closure(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
user_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
args = extra_args.copy()
|
||||
args["model"] = use_model.model_entity.name
|
||||
|
||||
messages: list[dict] = req_messages.copy()
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg["content"], list):
|
||||
if "content" in msg and isinstance(msg["content"], list):
|
||||
text_content: list = []
|
||||
image_urls: list = []
|
||||
for me in msg["content"]:
|
||||
@@ -58,12 +61,16 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
text_content.append(me["text"])
|
||||
elif me["type"] == "image_base64":
|
||||
image_urls.append(me["image_base64"])
|
||||
|
||||
|
||||
msg["content"] = "\n".join(text_content)
|
||||
msg["images"] = [url.split(',')[1] for url in image_urls]
|
||||
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
|
||||
for tool_call in msg['tool_calls']:
|
||||
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
|
||||
msg["images"] = [url.split(",")[1] for url in image_urls]
|
||||
if (
|
||||
"tool_calls" in msg
|
||||
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
|
||||
for tool_call in msg["tool_calls"]:
|
||||
tool_call["function"]["arguments"] = json.loads(
|
||||
tool_call["function"]["arguments"]
|
||||
)
|
||||
args["messages"] = messages
|
||||
|
||||
args["tools"] = []
|
||||
@@ -77,8 +84,8 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
return message
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completions: ollama.ChatResponse) -> llm_entities.Message:
|
||||
self, chat_completions: ollama.ChatResponse
|
||||
) -> llm_entities.Message:
|
||||
message: ollama.Message = chat_completions.message
|
||||
if message is None:
|
||||
raise ValueError("chat_completions must contain a 'message' field")
|
||||
@@ -86,43 +93,51 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
||||
ret_msg: llm_entities.Message = None
|
||||
|
||||
if message.content is not None:
|
||||
ret_msg = llm_entities.Message(
|
||||
role="assistant",
|
||||
content=message.content
|
||||
)
|
||||
ret_msg = llm_entities.Message(role="assistant", content=message.content)
|
||||
if message.tool_calls is not None and len(message.tool_calls) > 0:
|
||||
tool_calls: list[llm_entities.ToolCall] = []
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
tool_calls.append(llm_entities.ToolCall(
|
||||
id=uuid.uuid4().hex,
|
||||
type="function",
|
||||
function=llm_entities.FunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=json.dumps(tool_call.function.arguments)
|
||||
tool_calls.append(
|
||||
llm_entities.ToolCall(
|
||||
id=uuid.uuid4().hex,
|
||||
type="function",
|
||||
function=llm_entities.FunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=json.dumps(tool_call.function.arguments),
|
||||
),
|
||||
)
|
||||
))
|
||||
)
|
||||
ret_msg.tool_calls = tool_calls
|
||||
|
||||
return ret_msg
|
||||
|
||||
async def call(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: entities.LLMModelInfo,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
req_messages: list = []
|
||||
for m in messages:
|
||||
msg_dict: dict = m.dict(exclude_none=True)
|
||||
content: Any = msg_dict.get("content")
|
||||
if isinstance(content, list):
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
if all(
|
||||
isinstance(part, dict) and part.get("type") == "text"
|
||||
for part in content
|
||||
):
|
||||
msg_dict["content"] = "\n".join(part["text"] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
try:
|
||||
return await self._closure(query, req_messages, model, funcs, extra_args)
|
||||
return await self._closure(
|
||||
query=query,
|
||||
req_messages=req_messages,
|
||||
use_model=model,
|
||||
use_funcs=funcs,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
raise errors.RequesterError("请求超时")
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: Ollama
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.siliconflow.cn/v1',
|
||||
'base_url': 'https://api.siliconflow.cn/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 硅基流动
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://ark.cn-beijing.volces.com/api/v3',
|
||||
'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 火山方舟
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://api.x.ai/v1',
|
||||
'base_url': 'https://api.x.ai/v1',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: xAI
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -14,6 +14,6 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
client: openai.AsyncClient
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
|
||||
'base_url': 'https://open.bigmodel.cn/api/paas/v4',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
zh_CN: 智谱 AI
|
||||
spec:
|
||||
config:
|
||||
- name: base-url
|
||||
- name: base_url
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_CN: 基础 URL
|
||||
|
||||
@@ -27,11 +27,11 @@ class RequestRunner(abc.ABC):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
pipeline_config: dict
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import runner
|
||||
from ..core import app
|
||||
|
||||
from .runners import localagent
|
||||
from .runners import difysvapi
|
||||
from .runners import dashscopeapi
|
||||
|
||||
class RunnerManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
using_runner: runner.RequestRunner
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
for r in runner.preregistered_runners:
|
||||
if r.name == self.ap.provider_cfg.data['runner']:
|
||||
self.using_runner = r(self.ap)
|
||||
await self.using_runner.initialize()
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"未找到请求运行器: {self.ap.provider_cfg.data['runner']}")
|
||||
|
||||
def get_runner(self) -> runner.RequestRunner:
|
||||
return self.using_runner
|
||||
@@ -8,7 +8,7 @@ import re
|
||||
import dashscope
|
||||
|
||||
from .. import runner
|
||||
from ...core import entities as core_entities
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ...utils import image
|
||||
|
||||
@@ -29,12 +29,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
app_id: str # 应用ID
|
||||
api_key: str # API Key
|
||||
references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置)
|
||||
biz_params: dict = {} # 工作流应用参数(仅在工作流应用中生效)
|
||||
|
||||
async def initialize(self):
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
"""初始化"""
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
valid_app_types = ["agent", "workflow"]
|
||||
self.app_type = self.ap.provider_cfg.data["dashscope-app-api"]["app-type"]
|
||||
self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"]
|
||||
#检查配置文件中使用的应用类型是否支持
|
||||
if (self.app_type not in valid_app_types):
|
||||
raise DashscopeAPIError(
|
||||
@@ -42,10 +44,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
)
|
||||
|
||||
#初始化Dashscope 参数配置
|
||||
self.app_id = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["app-id"]
|
||||
self.api_key = self.ap.provider_cfg.data["dashscope-app-api"]["api-key"]
|
||||
self.references_quote = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["references_quote"]
|
||||
self.biz_params = self.ap.provider_cfg.data["dashscope-app-api"]["workflow"]["biz_params"]
|
||||
self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"]
|
||||
self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"]
|
||||
self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"]
|
||||
|
||||
def _replace_references(self, text, references_dict):
|
||||
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
|
||||
@@ -169,7 +170,6 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
biz_params = {}
|
||||
biz_params.update(self.biz_params)
|
||||
biz_params.update(query.variables)
|
||||
|
||||
#发送对话请求
|
||||
@@ -220,21 +220,19 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行"""
|
||||
if self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "agent":
|
||||
if self.app_type == "agent":
|
||||
async for msg in self._agent_messages(query):
|
||||
yield msg
|
||||
elif self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "workflow":
|
||||
elif self.app_type == "workflow":
|
||||
async for msg in self._workflow_messages(query):
|
||||
yield msg
|
||||
else:
|
||||
raise DashscopeAPIError(
|
||||
f"不支持的 Dashscope 应用类型: {self.ap.provider_cfg.data['dashscope-app-api']['app-type']}"
|
||||
f"不支持的 Dashscope 应用类型: {self.app_type}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import datetime
|
||||
import aiohttp
|
||||
|
||||
from .. import runner
|
||||
from ...core import entities as core_entities
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ...utils import image
|
||||
|
||||
@@ -23,24 +23,24 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
dify_client: client.AsyncDifyServiceClient
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化"""
|
||||
def __init__(self, ap: app.Application, pipeline_config: dict):
|
||||
self.ap = ap
|
||||
self.pipeline_config = pipeline_config
|
||||
|
||||
valid_app_types = ["chat", "agent", "workflow"]
|
||||
if (
|
||||
self.ap.provider_cfg.data["dify-service-api"]["app-type"]
|
||||
self.pipeline_config["ai"]["dify-service-api"]["app-type"]
|
||||
not in valid_app_types
|
||||
):
|
||||
raise errors.DifyAPIError(
|
||||
f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}"
|
||||
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
|
||||
)
|
||||
|
||||
api_key = self.ap.provider_cfg.data["dify-service-api"][
|
||||
self.ap.provider_cfg.data["dify-service-api"]["app-type"]
|
||||
]["api-key"]
|
||||
api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"]
|
||||
|
||||
self.dify_client = client.AsyncDifyServiceClient(
|
||||
api_key=api_key,
|
||||
base_url=self.ap.provider_cfg.data["dify-service-api"]["base-url"],
|
||||
base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"],
|
||||
)
|
||||
|
||||
def _try_convert_thinking(self, resp_text: str) -> str:
|
||||
@@ -48,13 +48,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if not resp_text.startswith("<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>"):
|
||||
return resp_text
|
||||
|
||||
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "original":
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original":
|
||||
return resp_text
|
||||
|
||||
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "remove":
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove":
|
||||
return re.sub(r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>', '', resp_text, flags=re.DOTALL)
|
||||
|
||||
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "plain":
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain":
|
||||
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
|
||||
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
|
||||
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
|
||||
@@ -121,7 +121,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
):
|
||||
self.ap.logger.debug("dify-chat-chunk: " + str(chunk))
|
||||
|
||||
@@ -177,7 +177,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
response_mode="streaming",
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
):
|
||||
self.ap.logger.debug("dify-agent-chunk: " + str(chunk))
|
||||
|
||||
@@ -264,7 +264,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
inputs=inputs,
|
||||
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
files=files,
|
||||
timeout=self.ap.provider_cfg.data["dify-service-api"]["workflow"]["timeout"],
|
||||
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
|
||||
):
|
||||
self.ap.logger.debug("dify-workflow-chunk: " + str(chunk))
|
||||
if chunk["event"] in ignored_events:
|
||||
@@ -301,11 +301,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="assistant",
|
||||
content=chunk["data"]["outputs"][
|
||||
self.ap.provider_cfg.data["dify-service-api"]["workflow"][
|
||||
"output-key"
|
||||
]
|
||||
],
|
||||
content=chunk["data"]["outputs"]["summary"],
|
||||
)
|
||||
|
||||
yield msg
|
||||
@@ -314,16 +310,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
if self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "chat":
|
||||
if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat":
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "agent":
|
||||
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent":
|
||||
async for msg in self._agent_chat_messages(query):
|
||||
yield msg
|
||||
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "workflow":
|
||||
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow":
|
||||
async for msg in self._workflow_messages(query):
|
||||
yield msg
|
||||
else:
|
||||
raise errors.DifyAPIError(
|
||||
f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}"
|
||||
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
|
||||
)
|
||||
|
||||
@@ -16,14 +16,12 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求
|
||||
"""
|
||||
await query.use_model.requester.preprocess(query)
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
|
||||
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
@@ -61,7 +59,7 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
|
||||
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...plugin import context as plugin_context
|
||||
from ...provider import entities as provider_entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
@@ -41,17 +42,30 @@ class SessionManager:
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
|
||||
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, prompt_config: list[dict]) -> core_entities.Conversation:
|
||||
"""获取对话或创建对话"""
|
||||
|
||||
if not session.conversations:
|
||||
session.conversations = []
|
||||
|
||||
# set prompt
|
||||
prompt_messages = []
|
||||
|
||||
for prompt_message in prompt_config:
|
||||
prompt_messages.append(provider_entities.Message(**prompt_message))
|
||||
|
||||
prompt = provider_entities.Prompt(
|
||||
name="default",
|
||||
messages=prompt_messages,
|
||||
)
|
||||
|
||||
if session.using_conversation is None:
|
||||
conversation = core_entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
prompt=prompt,
|
||||
messages=[],
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
|
||||
use_llm_model=await self.ap.model_mgr.get_model_by_uuid(
|
||||
query.pipeline_config['ai']['local-agent']['model']
|
||||
),
|
||||
use_funcs=await self.ap.tool_mgr.get_all_functions(
|
||||
plugin_enabled=True,
|
||||
),
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ...provider import entities
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
"""名称"""
|
||||
|
||||
messages: list[entities.Message]
|
||||
"""消息列表"""
|
||||
@@ -1,46 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_loaders: list[typing.Type[PromptLoader]] = []
|
||||
|
||||
def loader_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
|
||||
cls.name = name
|
||||
preregistered_loaders.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
prompts: list[entities.Prompt]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.prompts = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt,存放到prompts列表中
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_prompts(self) -> list[entities.Prompt]:
|
||||
"""获取Prompt列表
|
||||
"""
|
||||
return self.prompts
|
||||
@@ -1,39 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("full-scenario")
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
for file in os.listdir("data/scenario"):
|
||||
with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
file_json = json.loads(file_str)
|
||||
messages = []
|
||||
for msg in file_json["prompt"]:
|
||||
role = 'system'
|
||||
if "role" in msg:
|
||||
role = msg['role']
|
||||
messages.append(
|
||||
llm_entities.Message(
|
||||
role=role,
|
||||
content=msg['content'],
|
||||
)
|
||||
)
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=messages
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("normal")
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
|
||||
for name, cnt in self.ap.provider_cfg.data['prompt'].items():
|
||||
prompt = entities.Prompt(
|
||||
name=name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=cnt
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
for file in os.listdir("data/prompts"):
|
||||
with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=file_str
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
@@ -1,56 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
from . import loader
|
||||
from .loaders import single, scenario
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Prompt管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
loader_inst: loader.PromptLoader
|
||||
|
||||
default_prompt: str = 'default'
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
mode_name = self.ap.provider_cfg.data['prompt-mode']
|
||||
|
||||
loader_class = None
|
||||
|
||||
for loader_cls in loader.preregistered_loaders:
|
||||
if loader_cls.name == mode_name:
|
||||
loader_class = loader_cls
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
|
||||
|
||||
await self.loader_inst.initialize()
|
||||
await self.loader_inst.load()
|
||||
|
||||
def get_all_prompts(self) -> list[loader.entities.Prompt]:
|
||||
"""获取所有Prompt
|
||||
"""
|
||||
return self.loader_inst.get_prompts()
|
||||
|
||||
async def get_prompt(self, name: str) -> loader.entities.Prompt:
|
||||
"""获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name == name:
|
||||
return prompt
|
||||
|
||||
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
|
||||
"""通过前缀获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name.startswith(prefix):
|
||||
return prompt
|
||||
@@ -42,7 +42,7 @@ stages:
|
||||
zh_CN: 模型
|
||||
type: select
|
||||
required: true
|
||||
scope: llm-model
|
||||
scope: /provider/models/llm
|
||||
- name: max-round
|
||||
label:
|
||||
en_US: Max Round
|
||||
@@ -56,9 +56,14 @@ stages:
|
||||
zh_CN: 提示词
|
||||
type: array
|
||||
required: true
|
||||
default: []
|
||||
items:
|
||||
type: string
|
||||
type: object
|
||||
properties:
|
||||
role:
|
||||
type: string
|
||||
default: user
|
||||
content:
|
||||
type: string
|
||||
- name: dify-service-api
|
||||
label:
|
||||
en_US: Dify Service API
|
||||
|
||||
@@ -46,7 +46,7 @@ stages:
|
||||
zh_CN: 窗口长度(秒)
|
||||
type: integer
|
||||
required: true
|
||||
default: 10
|
||||
default: 60
|
||||
- name: limitation
|
||||
label:
|
||||
en_US: Limitation
|
||||
@@ -54,3 +54,19 @@ stages:
|
||||
type: integer
|
||||
required: true
|
||||
default: 60
|
||||
- name: strategy
|
||||
label:
|
||||
en_US: Strategy
|
||||
zh_CN: 策略
|
||||
type: select
|
||||
required: true
|
||||
default: drop
|
||||
options:
|
||||
- name: drop
|
||||
label:
|
||||
en_US: Drop
|
||||
zh_CN: 丢弃
|
||||
- name: wait
|
||||
label:
|
||||
en_US: Wait
|
||||
zh_CN: 等待
|
||||
Reference in New Issue
Block a user