Merge pull request #1245 from RockChinQ/feat/invoke-pipelines

feat: pipeline invoking
This commit is contained in:
Junyan Qin (Chin)
2025-04-03 18:05:22 +08:00
committed by GitHub
77 changed files with 663 additions and 895 deletions

View File

@@ -5,10 +5,25 @@ import datetime
import sqlalchemy
from ....core import app
from ....pipeline import stagemgr
from ....entity.persistence import pipeline as persistence_pipeline
default_stage_order = [
"GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", # 预处理器
"ConversationMessageTruncator", # 会话消息截断器
"RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器
"LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
]
class PipelineService:
ap: app.Application
@@ -49,7 +64,7 @@ class PipelineService:
async def create_pipeline(self, pipeline_data: dict) -> str:
pipeline_data['uuid'] = str(uuid.uuid4())
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
pipeline_data['stages'] = stagemgr.stage_order.copy()
pipeline_data['stages'] = default_stage_order.copy()
# TODO: 检查pipeline config是否完整
@@ -64,9 +79,12 @@ class PipelineService:
return pipeline_data['uuid']
async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None:
del pipeline_data['uuid']
del pipeline_data['for_version']
del pipeline_data['stages']
if 'uuid' in pipeline_data:
del pipeline_data['uuid']
if 'for_version' in pipeline_data:
del pipeline_data['for_version']
if 'stages' in pipeline_data:
del pipeline_data['stages']
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
)

View File

@@ -8,7 +8,7 @@ from . import entities, operator, errors
from ..config import manager as cfg_mgr
# 引入所有算子以便注册
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
class CommandManager:

View File

@@ -1,62 +0,0 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="default",
help="操作情景预设",
usage='!default\n!default set <指定情景预设为默认>'
)
class DefaultOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前所有情景预设: \n\n"
for prompt in self.ap.prompt_mgr.get_all_prompts():
content = ""
for msg in prompt.messages:
content += f" {msg.readable_str()}\n"
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
reply_str += f"当前会话使用的是: {context.session.use_prompt_name}"
yield entities.CommandReturn(text=reply_str.strip())
@operator.operator_class(
name="set",
help="设置当前会话默认情景预设",
parent_class=DefaultOperator
)
class DefaultSetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]
try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@@ -11,16 +11,14 @@ import os
from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..provider import runnermgr
from ..config import manager as config_mgr
from ..config import settings as settings_mgr
from ..audit.center import v2 as center_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..pipeline import pool
from ..pipeline import controller, stagemgr, pipelinemgr
from ..pipeline import controller, pipelinemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
@@ -53,12 +51,9 @@ class Application:
model_mgr: llm_model_mgr.ModelManager = None
prompt_mgr: llm_prompt_mgr.PromptManager = None
# TODO 移动到 pipeline 里
tool_mgr: llm_tool_mgr.ToolManager = None
runner_mgr: runnermgr.RunnerManager = None
settings_mgr: settings_mgr.SettingsManager = None
# ======= 配置管理器 =======
@@ -100,8 +95,6 @@ class Application:
ctrl: controller.Controller = None
stage_mgr: stagemgr.StageManager = None
pipeline_mgr: pipelinemgr.PipelineManager = None
ver_mgr: version_mgr.VersionManager = None
@@ -232,16 +225,8 @@ class Application:
await llm_session_mgr_inst.initialize()
self.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(self)
await llm_prompt_mgr_inst.initialize()
self.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst
runner_mgr_inst = runnermgr.RunnerManager(self)
await runner_mgr_inst.initialize()
self.runner_mgr = runner_mgr_inst
case _:
pass

View File

@@ -8,8 +8,7 @@ import asyncio
import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.modelmgr import entities, modelmgr, requester
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
@@ -57,6 +56,15 @@ class Query(pydantic.BaseModel):
message_chain: platform_message.MessageChain
"""消息链platform收到的原始消息链"""
bot_uuid: typing.Optional[str] = None
"""机器人UUID。"""
pipeline_uuid: typing.Optional[str] = None
"""流水线UUID。"""
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
"""流水线配置,由 Pipeline 在运行开始时设置。"""
adapter: msadapter.MessagePlatformAdapter
"""消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""
@@ -66,7 +74,7 @@ class Query(pydantic.BaseModel):
messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器阶段设置"""
prompt: typing.Optional[sysprompt_entities.Prompt] = None
prompt: typing.Optional[llm_entities.Prompt] = None
"""情景预设内容,由前置处理器阶段设置"""
user_message: typing.Optional[llm_entities.Message] = None
@@ -75,8 +83,8 @@ class Query(pydantic.BaseModel):
variables: typing.Optional[dict[str, typing.Any]] = None
"""变量由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
use_model: typing.Optional[entities.LLMModelInfo] = None
"""使用的模型,由前置处理器阶段设置"""
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
"""使用的对话模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
@@ -88,7 +96,7 @@ class Query(pydantic.BaseModel):
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======
current_stage: "pkg.pipeline.stagemgr.StageInstContainer" = None
current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None
"""当前所处阶段"""
class Config:
@@ -118,7 +126,7 @@ class Query(pydantic.BaseModel):
class Conversation(pydantic.BaseModel):
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: sysprompt_entities.Prompt
prompt: llm_entities.Prompt
messages: list[llm_entities.Message]
@@ -126,13 +134,16 @@ class Conversation(pydantic.BaseModel):
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_model: entities.LLMModelInfo
use_llm_model: requester.RuntimeLLMModel
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
uuid: typing.Optional[str] = None
"""该对话的 uuid在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
class Config:
arbitrary_types_allowed = True
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""

View File

@@ -6,14 +6,12 @@ from .. import stage, app
from ...utils import version, proxy, announce, platform
from ...audit.center import v2 as center_v2
from ...audit import identifier
from ...pipeline import pool, controller, stagemgr, pipelinemgr
from ...pipeline import pool, controller, pipelinemgr
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...provider import runnermgr
from ...platform import manager as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
@@ -61,10 +59,7 @@ class BuildAppStage(stage.BootingStage):
},
runtime_info={
"admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
"msg_source": str([
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
]),
"msg_source": str([]),
},
)
ap.ctr_mgr = center_v2_api
@@ -99,26 +94,14 @@ class BuildAppStage(stage.BootingStage):
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
runner_mgr_inst = runnermgr.RunnerManager(ap)
await runner_mgr_inst.initialize()
ap.runner_mgr = runner_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
pipeline_mgr = pipelinemgr.PipelineManager(ap)
await pipeline_mgr.initialize()
ap.pipeline_mgr = pipeline_mgr

View File

@@ -233,3 +233,10 @@ class AsyncTaskManager:
if not wrapper.task.done() and scope in wrapper.scopes:
wrapper.task.cancel()
def cancel_task(self, task_id: int):
for wrapper in self.tasks:
if wrapper.id == task_id:
if not wrapper.task.done():
wrapper.task.cancel()
return

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import re
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -13,7 +13,7 @@ class BanSessionCheckStage(stage.PipelineStage):
仅检查query中群号或个人号是否在访问控制列表中。
"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
pass
async def process(
@@ -24,9 +24,9 @@ class BanSessionCheckStage(stage.PipelineStage):
found = False
mode = self.ap.pipeline_cfg.data['access-control']['mode']
mode = query.pipeline_config['trigger']['access-control']['mode']
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
sess_list = query.pipeline_config['trigger']['access-control'][mode]
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from ...core import app
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities
@@ -35,17 +35,18 @@ class ContentFilterStage(stage.PipelineStage):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
async def initialize(self, pipeline_config: dict):
filters_required = [
"content-ignore",
]
if self.ap.pipeline_cfg.data['check-sensitive-words']:
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
filters_required.append("ban-word-filter")
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
filters_required.append("baidu-cloud-examine")
# TODO revert it
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
# filters_required.append("baidu-cloud-examine")
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
@@ -65,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
只要有一个不通过就不放行,只放行 PASS 的消息
"""
if not self.ap.pipeline_cfg.data['income-msg-check']:
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
@@ -73,7 +74,7 @@ class ContentFilterStage(stage.PipelineStage):
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
result = await filter.process(message)
result = await filter.process(query, message)
if result.level in [
filter_entities.ResultLevel.BLOCK,
@@ -105,7 +106,7 @@ class ContentFilterStage(stage.PipelineStage):
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
if message is None:
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
@@ -114,7 +115,7 @@ class ContentFilterStage(stage.PipelineStage):
message = message.strip()
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
result = await filter.process(query, message)
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import app, entities as core_entities
from . import entities
from ...provider import entities as llm_entities
@@ -64,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。

View File

@@ -4,6 +4,7 @@ import aiohttp
from .. import entities
from .. import filter as filter_model
from ....core import entities as core_entities
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
@@ -26,7 +27,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
) as resp:
return (await resp.json())['access_token']
async def process(self, message: str) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(

View File

@@ -3,7 +3,7 @@ import re
from .. import filter as filter_model
from .. import entities
from ....config import manager as cfg_mgr
from ....core import entities as core_entities
@filter_model.filter_class("ban-word-filter")
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self):
pass
async def process(self, message: str) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:

View File

@@ -3,6 +3,7 @@ import re
from .. import entities
from .. import filter as filter_model
from ....core import entities as core_entities
@filter_model.filter_class("content-ignore")
@@ -15,9 +16,9 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE,
]
async def process(self, message: str) -> entities.FilterResult:
if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']:
for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
@@ -26,8 +27,8 @@ class ContentIgnore(filter_model.ContentFilter):
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
)
if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']:
for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']:
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
if re.search(rule, message):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,

View File

@@ -50,14 +50,23 @@ class Controller:
continue
if selected_query:
async def _process_query(selected_query):
async def _process_query(selected_query: entities.Query):
async with self.semaphore: # 总并发上限
await self.process_query(selected_query)
# find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
if bot:
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
self.ap.task_mgr.create_task(
_process_query(selected_query),
kind="query",
@@ -70,127 +79,6 @@ class Controller:
self.ap.logger.error(f"控制器循环出错: {e}")
self.ap.logger.error(f"Traceback: {traceback.format_exc()}")
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出
"""
if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(
platform_message.Plain(result.user_notice)
)
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(
*result.user_notice
)
await self.ap.platform_mgr.send(
query.message_event,
result.user_notice,
query.adapter
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
if result.console_notice:
self.ap.logger.info(result.console_notice)
if result.error_notice:
self.ap.logger.error(result.error_notice)
async def _execute_from_stage(
self,
stage_index: int,
query: entities.Query,
):
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
如何看懂这里为什么这么写?
去问 GPT-4:
Q1: 现在有一个责任链其中有多个stagequery对象在其中传递stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None]
如果返回的是生成器需要挨个生成result检查是否result中是否要求继续如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了就继续生成下一个result
调用后续的stage直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
Q2: 不是这样的你可能理解有误。如果我们责任链上有这些Stage
A B C D E F G
如果所有的stage都返回Result且所有Result都要求继续那么执行顺序是
A B C D E F G
现在假设C返回的是AsyncGenerator那么执行顺序是
A B C D E F G C D E F G C D E F G ...
Q3: 但是如果不止一个stage会返回生成器呢
"""
i = stage_index
while i < len(self.ap.stage_mgr.stage_containers):
stage_container = self.ap.stage_mgr.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
async for sub_result in result:
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
i += 1
async def process_query(self, query: entities.Query):
"""处理请求
"""
try:
# ======== 触发 MessageReceived 事件 ========
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query
)
)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f"Processing query {query}")
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
finally:
self.ap.logger.debug(f"Query {query} processed")
async def run(self):
"""运行控制器
"""

View File

@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
from ...core import app
from . import strategy
from .strategies import image, forward
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...platform.types import message as platform_message
@@ -23,8 +23,8 @@ class LongTextProcessStage(stage.PipelineStage):
strategy_impl: strategy.LongTextStrategy
async def initialize(self):
config = self.ap.platform_cfg.data['long-text-process']
async def initialize(self, pipeline_config: dict):
config = pipeline_config['output']['long-text-processing']
if config['strategy'] == 'image':
use_font = config['font-path']
try:
@@ -42,12 +42,12 @@ class LongTextProcessStage(stage.PipelineStage):
else:
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
except:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
@@ -69,7 +69,7 @@ class LongTextProcessStage(stage.PipelineStage):
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']:
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult(

View File

@@ -8,6 +8,7 @@ import re
from PIL import Image, ImageDraw, ImageFont
import functools
from ....platform.types import message as platform_message
from .. import strategy as strategy_model
@@ -17,15 +18,18 @@ from ....core import entities as core_entities
@strategy_model.strategy_class("image")
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
pass
@functools.lru_cache(maxsize=16)
def get_font(self, query: core_entities.Query):
return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
save_as='temp/{}.png'.format(int(time.time())),
query=query
)
compressed_path, size = self.compress_image(
@@ -127,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
return outfile, self.get_size(outfile)
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None):
text_str = text_str.replace("\t", " ")
@@ -142,7 +146,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.text_render_font.getlength(line)
line_width = self.get_font(query).getlength(line)
self.ap.logger.debug("line_width: {}".format(line_width))
if line_width < text_width:
final_lines.append(line)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from . import truncator
from .truncators import round
@@ -14,8 +14,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
"""
trun: truncator.Truncator
async def initialize(self):
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
async def initialize(self, pipeline_config: dict):
use_method = "round"
for trun in truncator.preregistered_truncators:
if trun.name == use_method:

View File

@@ -12,7 +12,7 @@ class RoundTruncator(truncator.Truncator):
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
"""
max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']
max_round = query.pipeline_config['ai']['local-agent']['max-round']
temp_messages = []

View File

@@ -1,12 +1,40 @@
from __future__ import annotations
import typing
import traceback
import sqlalchemy
from ..core import app, entities
from . import entities as pipeline_entities
from ..entity.persistence import pipeline as persistence_pipeline
from . import stagemgr, stage
from . import stage
from ..platform.types import message as platform_message, events as platform_events
from ..plugin import events
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
from .msgtrun import msgtrun
class StageInstContainer():
"""阶段实例容器
"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class RuntimePipeline:
@@ -17,16 +45,146 @@ class RuntimePipeline:
pipeline_entity: persistence_pipeline.LegacyPipeline
"""流水线实体"""
stage_containers: list[stagemgr.StageInstContainer]
stage_containers: list[StageInstContainer]
"""阶段实例容器"""
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]):
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]):
self.ap = ap
self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers
async def run(self):
pass
async def run(self, query: entities.Query):
query.pipeline_config = self.pipeline_entity.config
await self.process_query(query)
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出
"""
if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(
platform_message.Plain(result.user_notice)
)
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(
*result.user_notice
)
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
result.user_notice.insert(
0,
platform_message.At(
query.message_event.sender.id
)
)
await query.adapter.reply_message(
message_source=query.message_event,
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin']
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
if result.console_notice:
self.ap.logger.info(result.console_notice)
if result.error_notice:
self.ap.logger.error(result.error_notice)
async def _execute_from_stage(
self,
stage_index: int,
query: entities.Query,
):
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
如何看懂这里为什么这么写?
去问 GPT-4:
Q1: 现在有一个责任链其中有多个stagequery对象在其中传递stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None]
如果返回的是生成器需要挨个生成result检查是否result中是否要求继续如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了就继续生成下一个result
调用后续的stage直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
Q2: 不是这样的你可能理解有误。如果我们责任链上有这些Stage
A B C D E F G
如果所有的stage都返回Result且所有Result都要求继续那么执行顺序是
A B C D E F G
现在假设C返回的是AsyncGenerator那么执行顺序是
A B C D E F G C D E F G C D E F G ...
Q3: 但是如果不止一个stage会返回生成器呢
"""
i = stage_index
while i < len(self.stage_containers):
stage_container = self.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
async for sub_result in result:
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
i += 1
async def process_query(self, query: entities.Query):
"""处理请求
"""
try:
# ======== 触发 MessageReceived 事件 ========
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query
)
)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f"Processing query {query}")
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
finally:
self.ap.logger.debug(f"Query {query} processed")
class PipelineManager:
@@ -70,12 +228,15 @@ class PipelineManager:
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers = []
stage_containers: list[StageInstContainer] = []
for stage_name in pipeline_entity.stages:
stage_containers.append(stagemgr.StageInstContainer(
stage_name=stage_name,
stage_class=self.stage_dict[stage_name]
stage_containers.append(StageInstContainer(
inst_name=stage_name,
inst=self.stage_dict[stage_name](self.ap)
))
for stage_container in stage_containers:
await stage_container.inst.initialize(pipeline_entity.config)
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
self.pipelines.append(runtime_pipeline)

View File

@@ -28,15 +28,17 @@ class QueryPool:
async def add_query(
self,
bot_uuid: str,
launcher_type: entities.LauncherTypes,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain,
adapter: msadapter.MessagePlatformAdapter
adapter: msadapter.MessagePlatformAdapter,
) -> entities.Query:
async with self.condition:
query = entities.Query(
bot_uuid=bot_uuid,
query_id=self.query_id_counter,
launcher_type=launcher_type,
launcher_id=launcher_id,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import datetime
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from ...plugin import events
@@ -33,16 +33,16 @@ class PreProcessor(stage.PipelineStage):
"""
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config['ai']['local-agent']['prompt'])
# 设置query
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
query.use_model = conversation.use_model
query.use_llm_model = conversation.use_llm_model
query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
query.variables = {
"session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}",
@@ -50,8 +50,9 @@ class PreProcessor(stage.PipelineStage):
"msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()),
}
# 检查vision是否启用没启用就删除所有图片
if not self.ap.provider_cfg.data['enable-vision'] or (self.ap.provider_cfg.data['runner'] == 'local-agent' and not query.use_model.vision_supported):
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -69,7 +70,7 @@ class PreProcessor(stage.PipelineStage):
)
plain_text += me.text
elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported):
if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_base64(me.base64)

View File

@@ -9,7 +9,9 @@ import json
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities, runnermgr
from ....provider import entities as llm_entities
from ....provider import runner as runner_module
from ....provider.runners import localagent, difysvapi, dashscopeapi
from ....plugin import events
from ....platform.types import message as platform_message
@@ -56,12 +58,6 @@ class ChatMessageHandler(handler.MessageHandler):
)
else:
if not self.ap.provider_cfg.data['enable-chat']:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
if event_ctx.event.alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
query.user_message.content = event_ctx.event.alter
@@ -72,7 +68,12 @@ class ChatMessageHandler(handler.MessageHandler):
try:
runner = self.ap.runner_mgr.get_runner()
for r in runner_module.preregistered_runners:
if r.name == query.pipeline_config["ai"]["runner"]["runner"]:
runner = r(self.ap, query.pipeline_config)
break
else:
raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}")
async for result in runner.run(query):
query.resp_messages.append(result)
@@ -93,10 +94,12 @@ class ChatMessageHandler(handler.MessageHandler):
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}',
user_notice='请求失败' if hide_exception_info else f'{e}',
error_notice=f'{e}',
debug_notice=traceback.format_exc()
)

View File

@@ -4,7 +4,7 @@ from ...core import app, entities as core_entities
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -23,7 +23,7 @@ class Processor(stage.PipelineStage):
chat_handler: handler.MessageHandler
async def initialize(self):
async def initialize(self, pipeline_config: dict):
self.cmd_handler = command.CommandHandler(self.ap)
self.chat_handler = chat.ChatMessageHandler(self.ap)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import app, entities as core_entities
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
@@ -31,7 +31,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
"""进入处理流程
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
@@ -46,7 +46,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
"""退出处理流程
Args:

View File

@@ -3,6 +3,7 @@ import asyncio
import time
import typing
from .. import algo
from ....core import entities as core_entities
# 固定窗口算法
class SessionContainer:
@@ -30,7 +31,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
# 加锁,找容器
container: SessionContainer = None
@@ -47,12 +48,13 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async with container.wait_lock:
# 获取窗口大小和限制
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit']
window_size = query.pipeline_config['safety']['rate-limit']['window-length']
limitation = query.pipeline_config['safety']['rate-limit']['limitation']
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
# TODO revert it
# if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
# window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
# limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
# 获取当前时间戳
now = int(time.time())
@@ -65,9 +67,9 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 如果访问次数超过了限制
if count >= limitation:
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
return False
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
# 等待下一窗口
await asyncio.sleep(window_size - time.time() % window_size)
@@ -84,5 +86,5 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 返回True
return True
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
pass

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import typing
from .. import entities, stagemgr, stage
from .. import entities, stage
from . import algo
from .algos import fixedwin
from ...core import entities as core_entities
@@ -18,9 +18,9 @@ class RateLimit(stage.PipelineStage):
algo: algo.ReteLimitAlgo
async def initialize(self):
async def initialize(self, pipeline_config: dict):
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
algo_name = 'fixwin'
algo_class = None
@@ -46,6 +46,7 @@ class RateLimit(stage.PipelineStage):
"""
if stage_inst_name == "RequireRateLimitOccupancy":
if await self.algo.require_access(
query,
query.launcher_type.value,
query.launcher_id,
):
@@ -62,6 +63,7 @@ class RateLimit(stage.PipelineStage):
)
elif stage_inst_name == "ReleaseRateLimitOccupancy":
await self.algo.release_access(
query,
query.launcher_type.value,
query.launcher_id,
)

View File

@@ -5,8 +5,10 @@ import asyncio
from ...core import app
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -19,8 +21,8 @@ class SendResponseBackStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
random_range = (self.ap.platform_cfg.data['force-delay']['min'], self.ap.platform_cfg.data['force-delay']['max'])
random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max'])
random_delay = random.uniform(*random_range)
@@ -31,10 +33,20 @@ class SendResponseBackStage(stage.PipelineStage):
await asyncio.sleep(random_delay)
await self.ap.platform_mgr.send(
query.message_event,
query.resp_message_chain[-1],
adapter=query.adapter
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
query.resp_message_chain[-1].insert(
0,
platform_message.At(
query.message_event.sender.id
)
)
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
await query.adapter.reply_message(
message_source=query.message_event,
message=query.resp_message_chain[-1],
quote_origin=quote_origin
)
return entities.StageProcessResult(

View File

@@ -5,7 +5,7 @@ from ...core import app
from . import entities as rule_entities, rule
from .rules import atbot, prefix, regexp, random
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -20,7 +20,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
"""初始化检查器
"""
@@ -39,12 +39,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
new_query=query
)
rules = self.ap.pipeline_cfg.data['respond-rules']
rules = query.pipeline_config['trigger']['group-respond-rules']
use_rule = rules['default']
use_rule = rules
if str(query.launcher_id) in rules:
use_rule = rules[str(query.launcher_id)]
# TODO revert it
# if str(query.launcher_id) in rules:
# use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)

View File

@@ -28,7 +28,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
async def initialize(self, pipeline_config: dict):
"""初始化
"""
pass

View File

@@ -1,71 +0,0 @@
from __future__ import annotations
from ..core import app
from . import stage
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
from .msgtrun import msgtrun
# 请求处理阶段顺序
stage_order = [
"GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", # 预处理器
"ConversationMessageTruncator", # 会话消息截断器
"RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器
"LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
]
class StageInstContainer():
"""阶段实例容器
"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class StageManager:
ap: app.Application
stage_containers: list[StageInstContainer]
def __init__(self, ap: app.Application):
self.ap = ap
self.stage_containers = []
async def initialize(self):
"""初始化
"""
for name, cls in stage.preregistered_stages.items():
self.stage_containers.append(StageInstContainer(
inst_name=name,
inst=cls(self.ap)
))
for stage_containers in self.stage_containers:
await stage_containers.inst.initialize()
# 按照 stage_order 排序
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))

View File

@@ -5,7 +5,7 @@ import typing
from ...core import app, entities as core_entities
from .. import entities
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...plugin import events
@@ -22,7 +22,7 @@ class ResponseWrapper(stage.PipelineStage):
- resp_message_chain
"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
pass
async def process(
@@ -110,7 +110,7 @@ class ResponseWrapper(stage.PipelineStage):
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
if query.pipeline_config['output']['misc']['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(

View File

@@ -50,6 +50,41 @@ class RuntimeBot:
self.adapter = adapter
self.task_context = taskmgr.TaskContext()
async def initialize(self):
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter,
)
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter):
await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter,
)
self.adapter.register_listener(
platform_events.FriendMessage,
on_friend_message
)
self.adapter.register_listener(
platform_events.GroupMessage,
on_group_message
)
async def run(self):
async def exception_wrapper():
@@ -78,14 +113,16 @@ class RuntimeBot:
async def shutdown(self):
await self.adapter.kill()
self.ap.task_mgr.cancel_task(self.task_wrapper.id)
# 控制QQ消息输入输出的类
class PlatformManager:
# adapter: msadapter.MessageSourceAdapter = None
adapters: list[msadapter.MessagePlatformAdapter] = []
adapters: list[msadapter.MessagePlatformAdapter] = [] # deprecated
message_platform_adapter_components: list[engine.Component] = []
message_platform_adapter_components: list[engine.Component] = [] # deprecated
# ====== 4.0 ======
ap: app.Application = None
@@ -135,49 +172,20 @@ class PlatformManager:
bot_entity = persistence_bot.Bot(**bot_entity._mapping)
elif isinstance(bot_entity, dict):
bot_entity = persistence_bot.Bot(**bot_entity)
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
adapter_inst = self.adapter_dict[bot_entity.adapter](
bot_entity.adapter_config,
self.ap
)
adapter_inst.register_listener(
platform_events.FriendMessage,
on_friend_message
)
adapter_inst.register_listener(
platform_events.GroupMessage,
on_group_message
)
runtime_bot = RuntimeBot(
ap=self.ap,
bot_entity=bot_entity,
adapter=adapter_inst
)
await runtime_bot.initialize()
self.bots.append(runtime_bot)
return runtime_bot
@@ -209,50 +217,36 @@ class PlatformManager:
return None
async def write_back_config(self, adapter_name: str, adapter_inst: msadapter.MessagePlatformAdapter, config: dict):
index = -2
# index = -2
for i, adapter in enumerate(self.adapters):
if adapter == adapter_inst:
index = i
break
# for i, adapter in enumerate(self.adapters):
# if adapter == adapter_inst:
# index = i
# break
if index == -2:
raise Exception('平台适配器未找到')
# if index == -2:
# raise Exception('平台适配器未找到')
# 只修改启用的适配器
real_index = -1
# # 只修改启用的适配器
# real_index = -1
for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
if adapter['enable']:
index -= 1
if index == -1:
real_index = i
break
# for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
# if adapter['enable']:
# index -= 1
# if index == -1:
# real_index = i
# break
new_cfg = {
'adapter': adapter_name,
'enable': True,
**config
}
self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
await self.ap.platform_cfg.dump_config()
# new_cfg = {
# 'adapter': adapter_name,
# 'enable': True,
# **config
# }
# self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
# await self.ap.platform_cfg.dump_config()
async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessagePlatformAdapter):
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage):
msg.insert(
0,
platform_message.At(
event.sender.id
)
)
await adapter.reply_message(
event,
msg,
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False
)
# TODO implement this
pass
async def run(self):
# This method will only be called when the application launching
@@ -264,4 +258,4 @@ class PlatformManager:
for bot in self.bots:
if bot.enable:
await bot.shutdown()
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)

View File

@@ -238,4 +238,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
await self.bot._server_app.run_task(**self.config)
async def kill(self) -> bool:
# Current issue: existing connection will not be closed
# self.should_shutdown = True
return False

View File

@@ -28,9 +28,11 @@ spec:
label:
en_US: Intents
zh_CN: 权限
type: array[string]
type: array
required: true
default: []
items:
type: string
execution:
python:
path: ./qqbotpy.py

View File

@@ -222,10 +222,10 @@ class EventContext:
Args:
message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
"""
await self.host.ap.platform_mgr.send(
event=self.event.query.message_event,
msg=message_chain,
adapter=self.event.query.adapter,
# TODO 添加 at_sender 和 quote_origin 参数
await self.event.query.adapter.reply_message(
message_source=self.event.query.message_event,
message=message_chain
)
async def send_message(

View File

@@ -4,6 +4,8 @@ import typing
import enum
import pydantic.v1 as pydantic
from pkg.provider import entities
from ..platform.types import message as platform_message
@@ -124,3 +126,13 @@ class Message(pydantic.BaseModel):
mc.insert(0, platform_message.Plain(prefix_text))
return platform_message.MessageChain(mc)
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import typing
import sqlalchemy
import pydantic.v1 as pydantic
from . import entities, requester
from ...core import app
@@ -16,23 +17,6 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
class RuntimeLLMModel:
"""运行时模型"""
model_entity: persistence_model.LLMModel
"""模型数据"""
token_mgr: token.TokenManager
"""api key管理器"""
requester: requester.LLMAPIRequester
"""请求器实例"""
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class ModelManager:
"""模型管理器"""
@@ -47,7 +31,7 @@ class ModelManager:
ap: app.Application
llm_models: list[RuntimeLLMModel]
llm_models: list[requester.RuntimeLLMModel]
requester_components: list[engine.Component]
@@ -99,16 +83,20 @@ class ModelManager:
elif isinstance(model_info, dict):
model_info = persistence_model.LLMModel(**model_info)
runtime_llm_model = RuntimeLLMModel(
requester_inst = self.requester_dict[model_info.requester](
ap=self.ap,
config=model_info.requester_config
)
await requester_inst.initialize()
runtime_llm_model = requester.RuntimeLLMModel(
model_entity=model_info,
token_mgr=token.TokenManager(
name=model_info.uuid,
tokens=model_info.api_keys,
),
requester=self.requester_dict[model_info.requester](
ap=self.ap,
config=model_info.requester_config
)
requester=requester_inst
)
self.llm_models.append(runtime_llm_model)

View File

@@ -6,8 +6,27 @@ import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from . import entities as modelmgr_entities
from ..tools import entities as tools_entities
from ...entity.persistence import model as persistence_model
from . import token
class RuntimeLLMModel:
"""运行时模型"""
model_entity: persistence_model.LLMModel
"""模型数据"""
token_mgr: token.TokenManager
"""api key管理器"""
requester: LLMAPIRequester
"""请求器实例"""
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class LLMAPIRequester(metaclass=abc.ABCMeta):
@@ -31,21 +50,11 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def initialize(self):
pass
async def preprocess(
self,
query: core_entities.Query,
):
"""预处理
在这里处理特定API对Query对象的兼容性问题。
"""
pass
@abc.abstractmethod
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: modelmgr_entities.LLMModelInfo,
model: RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
@@ -53,7 +62,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
"""调用API
Args:
model (modelmgr_entities.LLMModelInfo): 使用的模型信息
model (RuntimeLLMModel): 使用的模型信息
messages (typing.List[llm_entities.Message]): 消息对象列表
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.

View File

@@ -24,16 +24,16 @@ class AnthropicMessages(requester.LLMAPIRequester):
client: anthropic.AsyncAnthropic
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.anthropic.com/v1',
'base_url': 'https://api.anthropic.com/v1',
'timeout': 120,
}
async def initialize(self):
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
base_url=self.requester_cfg['base_url'],
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=typing.cast(httpx.Timeout, self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout']),
timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']),
limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS,
follow_redirects=True,
trust_env=True,
@@ -44,17 +44,18 @@ class AnthropicMessages(requester.LLMAPIRequester):
http_client=httpx_client,
)
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = model.name if model.model_name is None else model.model_name
args = extra_args.copy()
args["model"] = model.model_entity.name
# 处理消息

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: Anthropic
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 阿里云百炼
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -26,7 +26,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
"base-url": "https://api.openai.com/v1",
"base_url": "https://api.openai.com/v1",
"timeout": 120,
}
@@ -34,7 +34,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self.client = openai.AsyncClient(
api_key="",
base_url=self.requester_cfg["base-url"],
base_url=self.requester_cfg["base_url"],
timeout=self.requester_cfg["timeout"],
http_client=httpx.AsyncClient(
trust_env=True, timeout=self.requester_cfg["timeout"]
@@ -51,7 +51,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self,
chat_completion: chat_completion.ChatCompletion,
) -> llm_entities.Message:
chatcmpl_message = chat_completion.choices[0].message.dict()
chatcmpl_message = chat_completion.choices[0].message.model_dump()
# 确保 role 字段存在且不为 None
if "role" not in chatcmpl_message or chatcmpl_message["role"] is None:
@@ -65,16 +65,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg["args"].copy()
args["model"] = (
use_model.name if use_model.model_name is None else use_model.model_name
)
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
@@ -104,10 +102,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
return message
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: OpenAI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -13,7 +13,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Deepseek ChatCompletion API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.deepseek.com',
'base_url': 'https://api.deepseek.com',
'timeout': 120,
}
@@ -21,14 +21,14 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 深度求索
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -18,7 +18,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Gitee AI ChatCompletions API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://ai.gitee.com/v1',
'base_url': 'https://ai.gitee.com/v1',
'timeout': 120,
}
@@ -26,14 +26,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: Gitee AI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'http://127.0.0.1:1234/v1',
'base_url': 'http://127.0.0.1:1234/v1',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: LM Studio
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -15,7 +15,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Moonshot ChatCompletion API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.moonshot.cn/v1',
'base_url': 'https://api.moonshot.cn/v1',
'timeout': 120,
}
@@ -23,14 +23,14 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 月之暗面
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -22,35 +22,38 @@ REQUESTER_NAME: str = "ollama-chat"
class OllamaChatCompletions(requester.LLMAPIRequester):
"""Ollama平台 ChatCompletion API请求器"""
client: ollama.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'http://127.0.0.1:11434',
'timeout': 120,
"base_url": "http://127.0.0.1:11434",
"timeout": 120,
}
async def initialize(self):
os.environ['OLLAMA_HOST'] = self.requester_cfg['base-url']
self.client = ollama.AsyncClient(
timeout=self.requester_cfg['timeout']
)
os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"]
self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"])
async def _req(self,
args: dict,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
return await self.client.chat(
**args
)
async def _req(
self,
args: dict,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
return await self.client.chat(**args)
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
user_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message:
args: Any = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
async def _closure(
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
user_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
args = extra_args.copy()
args["model"] = use_model.model_entity.name
messages: list[dict] = req_messages.copy()
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
if "content" in msg and isinstance(msg["content"], list):
text_content: list = []
image_urls: list = []
for me in msg["content"]:
@@ -58,12 +61,16 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
text_content.append(me["text"])
elif me["type"] == "image_base64":
image_urls.append(me["image_base64"])
msg["content"] = "\n".join(text_content)
msg["images"] = [url.split(',')[1] for url in image_urls]
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
for tool_call in msg['tool_calls']:
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
msg["images"] = [url.split(",")[1] for url in image_urls]
if (
"tool_calls" in msg
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
for tool_call in msg["tool_calls"]:
tool_call["function"]["arguments"] = json.loads(
tool_call["function"]["arguments"]
)
args["messages"] = messages
args["tools"] = []
@@ -77,8 +84,8 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
return message
async def _make_msg(
self,
chat_completions: ollama.ChatResponse) -> llm_entities.Message:
self, chat_completions: ollama.ChatResponse
) -> llm_entities.Message:
message: ollama.Message = chat_completions.message
if message is None:
raise ValueError("chat_completions must contain a 'message' field")
@@ -86,43 +93,51 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
ret_msg: llm_entities.Message = None
if message.content is not None:
ret_msg = llm_entities.Message(
role="assistant",
content=message.content
)
ret_msg = llm_entities.Message(role="assistant", content=message.content)
if message.tool_calls is not None and len(message.tool_calls) > 0:
tool_calls: list[llm_entities.ToolCall] = []
for tool_call in message.tool_calls:
tool_calls.append(llm_entities.ToolCall(
id=uuid.uuid4().hex,
type="function",
function=llm_entities.FunctionCall(
name=tool_call.function.name,
arguments=json.dumps(tool_call.function.arguments)
tool_calls.append(
llm_entities.ToolCall(
id=uuid.uuid4().hex,
type="function",
function=llm_entities.FunctionCall(
name=tool_call.function.name,
arguments=json.dumps(tool_call.function.arguments),
),
)
))
)
ret_msg.tool_calls = tool_calls
return ret_msg
async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
async def invoke_llm(
self,
query: core_entities.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
req_messages: list = []
for m in messages:
msg_dict: dict = m.dict(exclude_none=True)
content: Any = msg_dict.get("content")
if isinstance(content, list):
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
if all(
isinstance(part, dict) and part.get("type") == "text"
for part in content
):
msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict)
try:
return await self._closure(query, req_messages, model, funcs, extra_args)
return await self._closure(
query=query,
req_messages=req_messages,
use_model=model,
use_funcs=funcs,
extra_args=extra_args,
)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
raise errors.RequesterError("请求超时")

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: Ollama
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.siliconflow.cn/v1',
'base_url': 'https://api.siliconflow.cn/v1',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 硅基流动
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://ark.cn-beijing.volces.com/api/v3',
'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 火山方舟
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.x.ai/v1',
'base_url': 'https://api.x.ai/v1',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: xAI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
'base_url': 'https://open.bigmodel.cn/api/paas/v4',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 智谱 AI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -27,11 +27,11 @@ class RequestRunner(abc.ABC):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
pipeline_config: dict
async def initialize(self):
pass
def __init__(self, ap: app.Application, pipeline_config: dict):
self.ap = ap
self.pipeline_config = pipeline_config
@abc.abstractmethod
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:

View File

@@ -1,30 +0,0 @@
from __future__ import annotations
from . import runner
from ..core import app
from .runners import localagent
from .runners import difysvapi
from .runners import dashscopeapi
class RunnerManager:
ap: app.Application
using_runner: runner.RequestRunner
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
for r in runner.preregistered_runners:
if r.name == self.ap.provider_cfg.data['runner']:
self.using_runner = r(self.ap)
await self.using_runner.initialize()
break
else:
raise ValueError(f"未找到请求运行器: {self.ap.provider_cfg.data['runner']}")
def get_runner(self) -> runner.RequestRunner:
return self.using_runner

View File

@@ -8,7 +8,7 @@ import re
import dashscope
from .. import runner
from ...core import entities as core_entities
from ...core import app, entities as core_entities
from .. import entities as llm_entities
from ...utils import image
@@ -29,12 +29,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
app_id: str # 应用ID
api_key: str # API Key
references_quote: str # 引用资料提示当展示回答来源功能开启时这个变量会作为引用资料名前的提示可在provider.json中配置
biz_params: dict = {} # 工作流应用参数(仅在工作流应用中生效)
async def initialize(self):
def __init__(self, ap: app.Application, pipeline_config: dict):
"""初始化"""
self.ap = ap
self.pipeline_config = pipeline_config
valid_app_types = ["agent", "workflow"]
self.app_type = self.ap.provider_cfg.data["dashscope-app-api"]["app-type"]
self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"]
#检查配置文件中使用的应用类型是否支持
if (self.app_type not in valid_app_types):
raise DashscopeAPIError(
@@ -42,10 +44,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
)
#初始化Dashscope 参数配置
self.app_id = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["app-id"]
self.api_key = self.ap.provider_cfg.data["dashscope-app-api"]["api-key"]
self.references_quote = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["references_quote"]
self.biz_params = self.ap.provider_cfg.data["dashscope-app-api"]["workflow"]["biz_params"]
self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"]
self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"]
self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"]
def _replace_references(self, text, references_dict):
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
@@ -169,7 +170,6 @@ class DashScopeAPIRunner(runner.RequestRunner):
plain_text, image_ids = await self._preprocess_user_message(query)
biz_params = {}
biz_params.update(self.biz_params)
biz_params.update(query.variables)
#发送对话请求
@@ -220,21 +220,19 @@ class DashScopeAPIRunner(runner.RequestRunner):
content=pending_content,
)
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行"""
if self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "agent":
if self.app_type == "agent":
async for msg in self._agent_messages(query):
yield msg
elif self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "workflow":
elif self.app_type == "workflow":
async for msg in self._workflow_messages(query):
yield msg
else:
raise DashscopeAPIError(
f"不支持的 Dashscope 应用类型: {self.ap.provider_cfg.data['dashscope-app-api']['app-type']}"
f"不支持的 Dashscope 应用类型: {self.app_type}"
)

View File

@@ -10,7 +10,7 @@ import datetime
import aiohttp
from .. import runner
from ...core import entities as core_entities
from ...core import app, entities as core_entities
from .. import entities as llm_entities
from ...utils import image
@@ -23,24 +23,24 @@ class DifyServiceAPIRunner(runner.RequestRunner):
dify_client: client.AsyncDifyServiceClient
async def initialize(self):
"""初始化"""
def __init__(self, ap: app.Application, pipeline_config: dict):
self.ap = ap
self.pipeline_config = pipeline_config
valid_app_types = ["chat", "agent", "workflow"]
if (
self.ap.provider_cfg.data["dify-service-api"]["app-type"]
self.pipeline_config["ai"]["dify-service-api"]["app-type"]
not in valid_app_types
):
raise errors.DifyAPIError(
f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}"
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
)
api_key = self.ap.provider_cfg.data["dify-service-api"][
self.ap.provider_cfg.data["dify-service-api"]["app-type"]
]["api-key"]
api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"]
self.dify_client = client.AsyncDifyServiceClient(
api_key=api_key,
base_url=self.ap.provider_cfg.data["dify-service-api"]["base-url"],
base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"],
)
def _try_convert_thinking(self, resp_text: str) -> str:
@@ -48,13 +48,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if not resp_text.startswith("<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>"):
return resp_text
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "original":
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original":
return resp_text
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "remove":
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove":
return re.sub(r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>', '', resp_text, flags=re.DOTALL)
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "plain":
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain":
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
@@ -121,7 +121,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
conversation_id=cov_id,
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
):
self.ap.logger.debug("dify-chat-chunk: " + str(chunk))
@@ -177,7 +177,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
response_mode="streaming",
conversation_id=cov_id,
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
):
self.ap.logger.debug("dify-agent-chunk: " + str(chunk))
@@ -264,7 +264,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
inputs=inputs,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["workflow"]["timeout"],
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
):
self.ap.logger.debug("dify-workflow-chunk: " + str(chunk))
if chunk["event"] in ignored_events:
@@ -301,11 +301,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
msg = llm_entities.Message(
role="assistant",
content=chunk["data"]["outputs"][
self.ap.provider_cfg.data["dify-service-api"]["workflow"][
"output-key"
]
],
content=chunk["data"]["outputs"]["summary"],
)
yield msg
@@ -314,16 +310,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
if self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "chat":
if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat":
async for msg in self._chat_messages(query):
yield msg
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "agent":
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent":
async for msg in self._agent_chat_messages(query):
yield msg
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "workflow":
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow":
async for msg in self._workflow_messages(query):
yield msg
else:
raise errors.DifyAPIError(
f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}"
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
)

View File

@@ -16,14 +16,12 @@ class LocalAgentRunner(runner.RequestRunner):
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
yield msg
@@ -61,7 +59,7 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
yield msg

View File

@@ -4,6 +4,7 @@ import asyncio
from ...core import app, entities as core_entities
from ...plugin import context as plugin_context
from ...provider import entities as provider_entities
class SessionManager:
@@ -41,17 +42,30 @@ class SessionManager:
self.session_list.append(session)
return session
async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, prompt_config: list[dict]) -> core_entities.Conversation:
"""获取对话或创建对话"""
if not session.conversations:
session.conversations = []
# set prompt
prompt_messages = []
for prompt_message in prompt_config:
prompt_messages.append(provider_entities.Message(**prompt_message))
prompt = provider_entities.Prompt(
name="default",
messages=prompt_messages,
)
if session.using_conversation is None:
conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
prompt=prompt,
messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
use_llm_model=await self.ap.model_mgr.get_model_by_uuid(
query.pipeline_config['ai']['local-agent']['model']
),
use_funcs=await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
),

View File

@@ -1,16 +0,0 @@
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
from ...provider import entities
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@@ -1,46 +0,0 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_loaders: list[typing.Type[PromptLoader]] = []
def loader_class(name: str):
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
cls.name = name
preregistered_loaders.append(cls)
return cls
return decorator
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
name: str
ap: app.Application
prompts: list[entities.Prompt]
def __init__(self, ap: app.Application):
self.ap = ap
self.prompts = []
async def initialize(self):
pass
@abc.abstractmethod
async def load(self):
"""加载Prompt存放到prompts列表中
"""
raise NotImplementedError
def get_prompts(self) -> list[entities.Prompt]:
"""获取Prompt列表
"""
return self.prompts

View File

@@ -1,39 +0,0 @@
from __future__ import annotations
import json
import os
from .. import loader
from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("full-scenario")
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""
async def load(self):
"""加载Prompt
"""
for file in os.listdir("data/scenario"):
with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
file_json = json.loads(file_str)
messages = []
for msg in file_json["prompt"]:
role = 'system'
if "role" in msg:
role = msg['role']
messages.append(
llm_entities.Message(
role=role,
content=msg['content'],
)
)
prompt = entities.Prompt(
name=file_name,
messages=messages
)
self.prompts.append(prompt)

View File

@@ -1,43 +0,0 @@
from __future__ import annotations
import os
from .. import loader
from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("normal")
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""
async def load(self):
"""加载Prompt
"""
for name, cnt in self.ap.provider_cfg.data['prompt'].items():
prompt = entities.Prompt(
name=name,
messages=[
llm_entities.Message(
role='system',
content=cnt
)
]
)
self.prompts.append(prompt)
for file in os.listdir("data/prompts"):
with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
prompt = entities.Prompt(
name=file_name,
messages=[
llm_entities.Message(
role='system',
content=file_str
)
]
)
self.prompts.append(prompt)

View File

@@ -1,56 +0,0 @@
from __future__ import annotations
from ...core import app
from . import loader
from .loaders import single, scenario
class PromptManager:
"""Prompt管理器
"""
ap: app.Application
loader_inst: loader.PromptLoader
default_prompt: str = 'default'
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
mode_name = self.ap.provider_cfg.data['prompt-mode']
loader_class = None
for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_class = loader_cls
break
else:
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()
def get_all_prompts(self) -> list[loader.entities.Prompt]:
"""获取所有Prompt
"""
return self.loader_inst.get_prompts()
async def get_prompt(self, name: str) -> loader.entities.Prompt:
"""获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name == name:
return prompt
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
"""通过前缀获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name.startswith(prefix):
return prompt

View File

@@ -42,7 +42,7 @@ stages:
zh_CN: 模型
type: select
required: true
scope: llm-model
scope: /provider/models/llm
- name: max-round
label:
en_US: Max Round
@@ -56,9 +56,14 @@ stages:
zh_CN: 提示词
type: array
required: true
default: []
items:
type: string
type: object
properties:
role:
type: string
default: user
content:
type: string
- name: dify-service-api
label:
en_US: Dify Service API

View File

@@ -46,7 +46,7 @@ stages:
zh_CN: 窗口长度(秒)
type: integer
required: true
default: 10
default: 60
- name: limitation
label:
en_US: Limitation
@@ -54,3 +54,19 @@ stages:
type: integer
required: true
default: 60
- name: strategy
label:
en_US: Strategy
zh_CN: 策略
type: select
required: true
default: drop
options:
- name: drop
label:
en_US: Drop
zh_CN: 丢弃
- name: wait
label:
en_US: Wait
zh_CN: 等待