feat: preliminarily implement pipeline invoking

This commit is contained in:
Junyan Qin
2025-03-29 17:50:45 +08:00
parent d01eadc70f
commit 9f15ab5000
57 changed files with 384 additions and 421 deletions

View File

@@ -5,10 +5,25 @@ import datetime
import sqlalchemy
from ....core import app
from ....pipeline import stagemgr
from ....entity.persistence import pipeline as persistence_pipeline
default_stage_order = [
"GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", # 预处理器
"ConversationMessageTruncator", # 会话消息截断器
"RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器
"LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
]
class PipelineService:
ap: app.Application
@@ -49,7 +64,7 @@ class PipelineService:
async def create_pipeline(self, pipeline_data: dict) -> str:
pipeline_data['uuid'] = str(uuid.uuid4())
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
pipeline_data['stages'] = stagemgr.stage_order.copy()
pipeline_data['stages'] = default_stage_order.copy()
# TODO: 检查pipeline config是否完整

View File

@@ -13,14 +13,13 @@ from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..provider import runnermgr
from ..config import manager as config_mgr
from ..config import settings as settings_mgr
from ..audit.center import v2 as center_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..pipeline import pool
from ..pipeline import controller, stagemgr, pipelinemgr
from ..pipeline import controller, pipelinemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
@@ -53,12 +52,12 @@ class Application:
model_mgr: llm_model_mgr.ModelManager = None
# TODO 移动到 pipeline 里
prompt_mgr: llm_prompt_mgr.PromptManager = None
# TODO 移动到 pipeline 里
tool_mgr: llm_tool_mgr.ToolManager = None
runner_mgr: runnermgr.RunnerManager = None
settings_mgr: settings_mgr.SettingsManager = None
# ======= 配置管理器 =======
@@ -100,8 +99,6 @@ class Application:
ctrl: controller.Controller = None
stage_mgr: stagemgr.StageManager = None
pipeline_mgr: pipelinemgr.PipelineManager = None
ver_mgr: version_mgr.VersionManager = None
@@ -239,9 +236,5 @@ class Application:
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst
runner_mgr_inst = runnermgr.RunnerManager(self)
await runner_mgr_inst.initialize()
self.runner_mgr = runner_mgr_inst
case _:
pass

View File

@@ -8,7 +8,7 @@ import asyncio
import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import entities
from ..provider.modelmgr import entities, modelmgr, requester
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
@@ -57,6 +57,9 @@ class Query(pydantic.BaseModel):
message_chain: platform_message.MessageChain
"""消息链platform收到的原始消息链"""
bot_uuid: typing.Optional[str] = None
"""机器人UUID。"""
pipeline_uuid: typing.Optional[str] = None
"""流水线UUID。"""
@@ -81,8 +84,8 @@ class Query(pydantic.BaseModel):
variables: typing.Optional[dict[str, typing.Any]] = None
"""变量由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
use_model: typing.Optional[entities.LLMModelInfo] = None
"""使用的模型,由前置处理器阶段设置"""
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
"""使用的对话模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
@@ -94,7 +97,7 @@ class Query(pydantic.BaseModel):
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======
current_stage: "pkg.pipeline.stagemgr.StageInstContainer" = None
current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None
"""当前所处阶段"""
class Config:
@@ -132,13 +135,16 @@ class Conversation(pydantic.BaseModel):
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_model: entities.LLMModelInfo
use_llm_model: requester.RuntimeLLMModel
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
uuid: typing.Optional[str] = None
"""该对话的 uuid在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
class Config:
arbitrary_types_allowed = True
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""

View File

@@ -6,14 +6,13 @@ from .. import stage, app
from ...utils import version, proxy, announce, platform
from ...audit.center import v2 as center_v2
from ...audit import identifier
from ...pipeline import pool, controller, stagemgr, pipelinemgr
from ...pipeline import pool, controller, pipelinemgr
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...provider import runnermgr
from ...platform import manager as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
@@ -61,10 +60,7 @@ class BuildAppStage(stage.BootingStage):
},
runtime_info={
"admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
"msg_source": str([
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
]),
"msg_source": str([]),
},
)
ap.ctr_mgr = center_v2_api
@@ -107,18 +103,10 @@ class BuildAppStage(stage.BootingStage):
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
runner_mgr_inst = runnermgr.RunnerManager(ap)
await runner_mgr_inst.initialize()
ap.runner_mgr = runner_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
pipeline_mgr = pipelinemgr.PipelineManager(ap)
await pipeline_mgr.initialize()
ap.pipeline_mgr = pipeline_mgr

View File

@@ -233,3 +233,10 @@ class AsyncTaskManager:
if not wrapper.task.done() and scope in wrapper.scopes:
wrapper.task.cancel()
def cancel_task(self, task_id: int):
for wrapper in self.tasks:
if wrapper.id == task_id:
if not wrapper.task.done():
wrapper.task.cancel()
return

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import re
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -13,7 +13,7 @@ class BanSessionCheckStage(stage.PipelineStage):
仅检查query中群号或个人号是否在访问控制列表中。
"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
pass
async def process(

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from ...core import app
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities
@@ -35,7 +35,7 @@ class ContentFilterStage(stage.PipelineStage):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
async def initialize(self, pipeline_config: dict):
filters_required = [
"content-ignore",

View File

@@ -54,9 +54,13 @@ class Controller:
async def _process_query(selected_query: entities.Query):
async with self.semaphore: # 总并发上限
# find pipeline
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(selected_query.pipeline_uuid)
if pipeline:
await pipeline.run(selected_query)
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
if bot:
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()

View File

@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
from ...core import app
from . import strategy
from .strategies import image, forward
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...platform.types import message as platform_message
@@ -23,8 +23,8 @@ class LongTextProcessStage(stage.PipelineStage):
strategy_impl: strategy.LongTextStrategy
async def initialize(self):
config = self.ap.platform_cfg.data['long-text-process']
async def initialize(self, pipeline_config: dict):
config = pipeline_config['output']['long-text-processing']
if config['strategy'] == 'image':
use_font = config['font-path']
try:
@@ -42,12 +42,12 @@ class LongTextProcessStage(stage.PipelineStage):
else:
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
except:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
@@ -69,7 +69,7 @@ class LongTextProcessStage(stage.PipelineStage):
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']:
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult(

View File

@@ -8,6 +8,7 @@ import re
from PIL import Image, ImageDraw, ImageFont
import functools
from ....platform.types import message as platform_message
from .. import strategy as strategy_model
@@ -17,15 +18,18 @@ from ....core import entities as core_entities
@strategy_model.strategy_class("image")
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
pass
@functools.lru_cache(maxsize=16)
def get_font(self, query: core_entities.Query):
return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
save_as='temp/{}.png'.format(int(time.time())),
query=query
)
compressed_path, size = self.compress_image(
@@ -127,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
return outfile, self.get_size(outfile)
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None):
text_str = text_str.replace("\t", " ")
@@ -142,7 +146,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.text_render_font.getlength(line)
line_width = self.get_font(query).getlength(line)
self.ap.logger.debug("line_width: {}".format(line_width))
if line_width < text_width:
final_lines.append(line)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from . import truncator
from .truncators import round
@@ -14,7 +14,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
"""
trun: truncator.Truncator
async def initialize(self):
async def initialize(self, pipeline_config: dict):
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
for trun in truncator.preregistered_truncators:

View File

@@ -8,10 +8,35 @@ import sqlalchemy
from ..core import app, entities
from . import entities as pipeline_entities
from ..entity.persistence import pipeline as persistence_pipeline
from . import stagemgr, stage
from . import stage
from ..platform.types import message as platform_message, events as platform_events
from ..plugin import events
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
from .msgtrun import msgtrun
class StageInstContainer():
"""阶段实例容器
"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class RuntimePipeline:
"""运行时流水线"""
@@ -20,10 +45,10 @@ class RuntimePipeline:
pipeline_entity: persistence_pipeline.LegacyPipeline
"""流水线实体"""
stage_containers: list[stagemgr.StageInstContainer]
stage_containers: list[StageInstContainer]
"""阶段实例容器"""
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]):
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]):
self.ap = ap
self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers
@@ -47,10 +72,18 @@ class RuntimePipeline:
*result.user_notice
)
await self.ap.platform_mgr.send(
query.message_event,
result.user_notice,
query.adapter
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
result.user_notice.insert(
0,
platform_message.At(
query.message_event.sender.id
)
)
await query.adapter.reply_message(
message_source=query.message_event,
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin']
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
@@ -195,12 +228,15 @@ class PipelineManager:
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers = []
stage_containers: list[StageInstContainer] = []
for stage_name in pipeline_entity.stages:
stage_containers.append(stagemgr.StageInstContainer(
stage_containers.append(StageInstContainer(
inst_name=stage_name,
inst=self.stage_dict[stage_name](self.ap)
))
for stage_container in stage_containers:
await stage_container.inst.initialize(pipeline_entity.config)
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
self.pipelines.append(runtime_pipeline)

View File

@@ -28,23 +28,23 @@ class QueryPool:
async def add_query(
self,
bot_uuid: str,
launcher_type: entities.LauncherTypes,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain,
adapter: msadapter.MessagePlatformAdapter,
pipeline_uuid: str
) -> entities.Query:
async with self.condition:
query = entities.Query(
bot_uuid=bot_uuid,
query_id=self.query_id_counter,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain,
pipeline_uuid=pipeline_uuid,
resp_messages=[],
resp_message_chain=[],
adapter=adapter

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import datetime
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from ...plugin import events
@@ -33,16 +33,16 @@ class PreProcessor(stage.PipelineStage):
"""
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
conversation = await self.ap.sess_mgr.get_conversation(query, session)
# 设置query
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
query.use_model = conversation.use_model
query.use_llm_model = conversation.use_llm_model
query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
query.variables = {
"session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}",
@@ -50,8 +50,9 @@ class PreProcessor(stage.PipelineStage):
"msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()),
}
# 检查vision是否启用没启用就删除所有图片
if not self.ap.provider_cfg.data['enable-vision'] or (self.ap.provider_cfg.data['runner'] == 'local-agent' and not query.use_model.vision_supported):
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -69,7 +70,7 @@ class PreProcessor(stage.PipelineStage):
)
plain_text += me.text
elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported):
if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_base64(me.base64)

View File

@@ -9,7 +9,9 @@ import json
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities, runnermgr
from ....provider import entities as llm_entities
from ....provider import runner as runner_module
from ....provider.runners import localagent, difysvapi, dashscopeapi
from ....plugin import events
from ....platform.types import message as platform_message
@@ -56,12 +58,6 @@ class ChatMessageHandler(handler.MessageHandler):
)
else:
if not self.ap.provider_cfg.data['enable-chat']:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
if event_ctx.event.alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
query.user_message.content = event_ctx.event.alter
@@ -72,7 +68,12 @@ class ChatMessageHandler(handler.MessageHandler):
try:
runner = self.ap.runner_mgr.get_runner()
for r in runner_module.preregistered_runners:
if r.name == query.pipeline_config["ai"]["runner"]["runner"]:
runner = r(self.ap, query.pipeline_config)
break
else:
raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}")
async for result in runner.run(query):
query.resp_messages.append(result)
@@ -93,10 +94,12 @@ class ChatMessageHandler(handler.MessageHandler):
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}',
user_notice='请求失败' if hide_exception_info else f'{e}',
error_notice=f'{e}',
debug_notice=traceback.format_exc()
)

View File

@@ -4,7 +4,7 @@ from ...core import app, entities as core_entities
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -23,7 +23,7 @@ class Processor(stage.PipelineStage):
chat_handler: handler.MessageHandler
async def initialize(self):
async def initialize(self, pipeline_config: dict):
self.cmd_handler = command.CommandHandler(self.ap)
self.chat_handler = chat.ChatMessageHandler(self.ap)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import typing
from .. import entities, stagemgr, stage
from .. import entities, stage
from . import algo
from .algos import fixedwin
from ...core import entities as core_entities
@@ -18,7 +18,7 @@ class RateLimit(stage.PipelineStage):
algo: algo.ReteLimitAlgo
async def initialize(self):
async def initialize(self, pipeline_config: dict):
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']

View File

@@ -5,8 +5,10 @@ import asyncio
from ...core import app
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -19,8 +21,8 @@ class SendResponseBackStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
random_range = (self.ap.platform_cfg.data['force-delay']['min'], self.ap.platform_cfg.data['force-delay']['max'])
random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max'])
random_delay = random.uniform(*random_range)
@@ -31,10 +33,20 @@ class SendResponseBackStage(stage.PipelineStage):
await asyncio.sleep(random_delay)
await self.ap.platform_mgr.send(
query.message_event,
query.resp_message_chain[-1],
adapter=query.adapter
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
query.resp_message_chain[-1].insert(
0,
platform_message.At(
query.message_event.sender.id
)
)
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
await query.adapter.reply_message(
message_source=query.message_event,
message=query.resp_message_chain[-1],
quote_origin=quote_origin
)
return entities.StageProcessResult(

View File

@@ -5,7 +5,7 @@ from ...core import app
from . import entities as rule_entities, rule
from .rules import atbot, prefix, regexp, random
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -20,7 +20,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
"""初始化检查器
"""

View File

@@ -28,7 +28,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
async def initialize(self, pipeline_config: dict):
"""初始化
"""
pass

View File

@@ -1,71 +0,0 @@
from __future__ import annotations
from ..core import app
from . import stage
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
from .msgtrun import msgtrun
# 请求处理阶段顺序
stage_order = [
"GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", # 预处理器
"ConversationMessageTruncator", # 会话消息截断器
"RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器
"LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
]
class StageInstContainer():
"""阶段实例容器
"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class StageManager:
ap: app.Application
stage_containers: list[StageInstContainer]
def __init__(self, ap: app.Application):
self.ap = ap
self.stage_containers = []
async def initialize(self):
"""初始化
"""
for name, cls in stage.preregistered_stages.items():
self.stage_containers.append(StageInstContainer(
inst_name=name,
inst=cls(self.ap)
))
for stage_containers in self.stage_containers:
await stage_containers.inst.initialize()
# 按照 stage_order 排序
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))

View File

@@ -5,7 +5,7 @@ import typing
from ...core import app, entities as core_entities
from .. import entities
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...plugin import events
@@ -22,7 +22,7 @@ class ResponseWrapper(stage.PipelineStage):
- resp_message_chain
"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
pass
async def process(
@@ -110,7 +110,7 @@ class ResponseWrapper(stage.PipelineStage):
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
if query.pipeline_config['output']['misc']['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(

View File

@@ -55,25 +55,25 @@ class RuntimeBot:
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter,
pipeline_uuid=self.bot_entity.use_pipeline_uuid
)
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter):
await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter,
pipeline_uuid=self.bot_entity.use_pipeline_uuid
)
self.adapter.register_listener(
@@ -113,14 +113,16 @@ class RuntimeBot:
async def shutdown(self):
await self.adapter.kill()
self.ap.task_mgr.cancel_task(self.task_wrapper.id)
# 控制QQ消息输入输出的类
class PlatformManager:
# adapter: msadapter.MessageSourceAdapter = None
adapters: list[msadapter.MessagePlatformAdapter] = []
adapters: list[msadapter.MessagePlatformAdapter] = [] # deprecated
message_platform_adapter_components: list[engine.Component] = []
message_platform_adapter_components: list[engine.Component] = [] # deprecated
# ====== 4.0 ======
ap: app.Application = None
@@ -215,50 +217,36 @@ class PlatformManager:
return None
async def write_back_config(self, adapter_name: str, adapter_inst: msadapter.MessagePlatformAdapter, config: dict):
index = -2
# index = -2
for i, adapter in enumerate(self.adapters):
if adapter == adapter_inst:
index = i
break
# for i, adapter in enumerate(self.adapters):
# if adapter == adapter_inst:
# index = i
# break
if index == -2:
raise Exception('平台适配器未找到')
# if index == -2:
# raise Exception('平台适配器未找到')
# 只修改启用的适配器
real_index = -1
# # 只修改启用的适配器
# real_index = -1
for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
if adapter['enable']:
index -= 1
if index == -1:
real_index = i
break
# for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
# if adapter['enable']:
# index -= 1
# if index == -1:
# real_index = i
# break
new_cfg = {
'adapter': adapter_name,
'enable': True,
**config
}
self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
await self.ap.platform_cfg.dump_config()
# new_cfg = {
# 'adapter': adapter_name,
# 'enable': True,
# **config
# }
# self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
# await self.ap.platform_cfg.dump_config()
async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessagePlatformAdapter):
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage):
msg.insert(
0,
platform_message.At(
event.sender.id
)
)
await adapter.reply_message(
event,
msg,
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False
)
# TODO implement this
pass
async def run(self):
# This method will only be called when the application launching
@@ -270,4 +258,4 @@ class PlatformManager:
for bot in self.bots:
if bot.enable:
await bot.shutdown()
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)

View File

@@ -238,4 +238,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
await self.bot._server_app.run_task(**self.config)
async def kill(self) -> bool:
# Current issue: existing connection will not be closed
# self.should_shutdown = True
return False

View File

@@ -222,10 +222,10 @@ class EventContext:
Args:
message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
"""
await self.host.ap.platform_mgr.send(
event=self.event.query.message_event,
msg=message_chain,
adapter=self.event.query.adapter,
# TODO 添加 at_sender 和 quote_origin 参数
await self.event.query.adapter.reply_message(
message_source=self.event.query.message_event,
message=message_chain
)
async def send_message(

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import typing
import sqlalchemy
import pydantic.v1 as pydantic
from . import entities, requester
from ...core import app
@@ -16,23 +17,6 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
class RuntimeLLMModel:
"""运行时模型"""
model_entity: persistence_model.LLMModel
"""模型数据"""
token_mgr: token.TokenManager
"""api key管理器"""
requester: requester.LLMAPIRequester
"""请求器实例"""
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class ModelManager:
"""模型管理器"""
@@ -47,7 +31,7 @@ class ModelManager:
ap: app.Application
llm_models: list[RuntimeLLMModel]
llm_models: list[requester.RuntimeLLMModel]
requester_components: list[engine.Component]
@@ -99,16 +83,20 @@ class ModelManager:
elif isinstance(model_info, dict):
model_info = persistence_model.LLMModel(**model_info)
runtime_llm_model = RuntimeLLMModel(
requester_inst = self.requester_dict[model_info.requester](
ap=self.ap,
config=model_info.requester_config
)
await requester_inst.initialize()
runtime_llm_model = requester.RuntimeLLMModel(
model_entity=model_info,
token_mgr=token.TokenManager(
name=model_info.uuid,
tokens=model_info.api_keys,
),
requester=self.requester_dict[model_info.requester](
ap=self.ap,
config=model_info.requester_config
)
requester=requester_inst
)
self.llm_models.append(runtime_llm_model)

View File

@@ -6,8 +6,27 @@ import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from . import entities as modelmgr_entities
from ..tools import entities as tools_entities
from ...entity.persistence import model as persistence_model
from . import token
class RuntimeLLMModel:
"""运行时模型"""
model_entity: persistence_model.LLMModel
"""模型数据"""
token_mgr: token.TokenManager
"""api key管理器"""
requester: LLMAPIRequester
"""请求器实例"""
def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester):
self.model_entity = model_entity
self.token_mgr = token_mgr
self.requester = requester
class LLMAPIRequester(metaclass=abc.ABCMeta):
@@ -31,21 +50,11 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def initialize(self):
pass
async def preprocess(
self,
query: core_entities.Query,
):
"""预处理
在这里处理特定API对Query对象的兼容性问题。
"""
pass
@abc.abstractmethod
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: modelmgr_entities.LLMModelInfo,
model: RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
@@ -53,7 +62,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
"""调用API
Args:
model (modelmgr_entities.LLMModelInfo): 使用的模型信息
model (RuntimeLLMModel): 使用的模型信息
messages (typing.List[llm_entities.Message]): 消息对象列表
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.

View File

@@ -24,16 +24,16 @@ class AnthropicMessages(requester.LLMAPIRequester):
client: anthropic.AsyncAnthropic
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.anthropic.com/v1',
'base_url': 'https://api.anthropic.com/v1',
'timeout': 120,
}
async def initialize(self):
httpx_client = anthropic._base_client.AsyncHttpxClientWrapper(
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
base_url=self.requester_cfg['base_url'],
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=typing.cast(httpx.Timeout, self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout']),
timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']),
limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS,
follow_redirects=True,
trust_env=True,
@@ -44,17 +44,18 @@ class AnthropicMessages(requester.LLMAPIRequester):
http_client=httpx_client,
)
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = model.name if model.model_name is None else model.model_name
args = extra_args.copy()
args["model"] = model.model_entity.name
# 处理消息

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: Anthropic
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 阿里云百炼
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -26,7 +26,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
"base-url": "https://api.openai.com/v1",
"base_url": "https://api.openai.com/v1",
"timeout": 120,
}
@@ -34,7 +34,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self.client = openai.AsyncClient(
api_key="",
base_url=self.requester_cfg["base-url"],
base_url=self.requester_cfg["base_url"],
timeout=self.requester_cfg["timeout"],
http_client=httpx.AsyncClient(
trust_env=True, timeout=self.requester_cfg["timeout"]
@@ -65,16 +65,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg["args"].copy()
args["model"] = (
use_model.name if use_model.model_name is None else use_model.model_name
)
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
@@ -104,10 +102,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
return message
async def call(
async def invoke_llm(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: OpenAI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -13,7 +13,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Deepseek ChatCompletion API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.deepseek.com',
'base_url': 'https://api.deepseek.com',
'timeout': 120,
}
@@ -21,14 +21,14 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 深度求索
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -18,7 +18,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Gitee AI ChatCompletions API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://ai.gitee.com/v1',
'base_url': 'https://ai.gitee.com/v1',
'timeout': 120,
}
@@ -26,14 +26,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: Gitee AI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'http://127.0.0.1:1234/v1',
'base_url': 'http://127.0.0.1:1234/v1',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: LM Studio
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -15,7 +15,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Moonshot ChatCompletion API 请求器"""
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.moonshot.cn/v1',
'base_url': 'https://api.moonshot.cn/v1',
'timeout': 120,
}
@@ -23,14 +23,14 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
args = extra_args.copy()
args["model"] = use_model.model_entity.name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 月之暗面
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -22,35 +22,38 @@ REQUESTER_NAME: str = "ollama-chat"
class OllamaChatCompletions(requester.LLMAPIRequester):
"""Ollama平台 ChatCompletion API请求器"""
client: ollama.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'http://127.0.0.1:11434',
'timeout': 120,
"base_url": "http://127.0.0.1:11434",
"timeout": 120,
}
async def initialize(self):
os.environ['OLLAMA_HOST'] = self.requester_cfg['base-url']
self.client = ollama.AsyncClient(
timeout=self.requester_cfg['timeout']
)
os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"]
self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"])
async def _req(self,
args: dict,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
return await self.client.chat(
**args
)
async def _req(
self,
args: dict,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
return await self.client.chat(**args)
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
user_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message:
args: Any = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
async def _closure(
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
user_funcs: list[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
args = extra_args.copy()
args["model"] = use_model.model_entity.name
messages: list[dict] = req_messages.copy()
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
if "content" in msg and isinstance(msg["content"], list):
text_content: list = []
image_urls: list = []
for me in msg["content"]:
@@ -58,12 +61,16 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
text_content.append(me["text"])
elif me["type"] == "image_base64":
image_urls.append(me["image_base64"])
msg["content"] = "\n".join(text_content)
msg["images"] = [url.split(',')[1] for url in image_urls]
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
for tool_call in msg['tool_calls']:
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
msg["images"] = [url.split(",")[1] for url in image_urls]
if (
"tool_calls" in msg
): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
for tool_call in msg["tool_calls"]:
tool_call["function"]["arguments"] = json.loads(
tool_call["function"]["arguments"]
)
args["messages"] = messages
args["tools"] = []
@@ -77,8 +84,8 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
return message
async def _make_msg(
self,
chat_completions: ollama.ChatResponse) -> llm_entities.Message:
self, chat_completions: ollama.ChatResponse
) -> llm_entities.Message:
message: ollama.Message = chat_completions.message
if message is None:
raise ValueError("chat_completions must contain a 'message' field")
@@ -86,43 +93,51 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
ret_msg: llm_entities.Message = None
if message.content is not None:
ret_msg = llm_entities.Message(
role="assistant",
content=message.content
)
ret_msg = llm_entities.Message(role="assistant", content=message.content)
if message.tool_calls is not None and len(message.tool_calls) > 0:
tool_calls: list[llm_entities.ToolCall] = []
for tool_call in message.tool_calls:
tool_calls.append(llm_entities.ToolCall(
id=uuid.uuid4().hex,
type="function",
function=llm_entities.FunctionCall(
name=tool_call.function.name,
arguments=json.dumps(tool_call.function.arguments)
tool_calls.append(
llm_entities.ToolCall(
id=uuid.uuid4().hex,
type="function",
function=llm_entities.FunctionCall(
name=tool_call.function.name,
arguments=json.dumps(tool_call.function.arguments),
),
)
))
)
ret_msg.tool_calls = tool_calls
return ret_msg
async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
async def invoke_llm(
self,
query: core_entities.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
req_messages: list = []
for m in messages:
msg_dict: dict = m.dict(exclude_none=True)
content: Any = msg_dict.get("content")
if isinstance(content, list):
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
if all(
isinstance(part, dict) and part.get("type") == "text"
for part in content
):
msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict)
try:
return await self._closure(query, req_messages, model, funcs, extra_args)
return await self._closure(
query=query,
req_messages=req_messages,
use_model=model,
use_funcs=funcs,
extra_args=extra_args,
)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
raise errors.RequesterError("请求超时")

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: Ollama
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.siliconflow.cn/v1',
'base_url': 'https://api.siliconflow.cn/v1',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 硅基流动
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://ark.cn-beijing.volces.com/api/v3',
'base_url': 'https://ark.cn-beijing.volces.com/api/v3',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 火山方舟
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://api.x.ai/v1',
'base_url': 'https://api.x.ai/v1',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: xAI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -14,6 +14,6 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions):
client: openai.AsyncClient
default_config: dict[str, typing.Any] = {
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
'base_url': 'https://open.bigmodel.cn/api/paas/v4',
'timeout': 120,
}

View File

@@ -7,7 +7,7 @@ metadata:
zh_CN: 智谱 AI
spec:
config:
- name: base-url
- name: base_url
label:
en_US: Base URL
zh_CN: 基础 URL

View File

@@ -27,11 +27,11 @@ class RequestRunner(abc.ABC):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
pipeline_config: dict
async def initialize(self):
pass
def __init__(self, ap: app.Application, pipeline_config: dict):
self.ap = ap
self.pipeline_config = pipeline_config
@abc.abstractmethod
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:

View File

@@ -1,30 +0,0 @@
from __future__ import annotations
from . import runner
from ..core import app
from .runners import localagent
from .runners import difysvapi
from .runners import dashscopeapi
class RunnerManager:
ap: app.Application
using_runner: runner.RequestRunner
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
for r in runner.preregistered_runners:
if r.name == self.ap.provider_cfg.data['runner']:
self.using_runner = r(self.ap)
await self.using_runner.initialize()
break
else:
raise ValueError(f"未找到请求运行器: {self.ap.provider_cfg.data['runner']}")
def get_runner(self) -> runner.RequestRunner:
return self.using_runner

View File

@@ -8,7 +8,7 @@ import re
import dashscope
from .. import runner
from ...core import entities as core_entities
from ...core import app, entities as core_entities
from .. import entities as llm_entities
from ...utils import image
@@ -29,12 +29,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
app_id: str # 应用ID
api_key: str # API Key
references_quote: str # 引用资料提示当展示回答来源功能开启时这个变量会作为引用资料名前的提示可在provider.json中配置
biz_params: dict = {} # 工作流应用参数(仅在工作流应用中生效)
async def initialize(self):
def __init__(self, ap: app.Application, pipeline_config: dict):
"""初始化"""
self.ap = ap
self.pipeline_config = pipeline_config
valid_app_types = ["agent", "workflow"]
self.app_type = self.ap.provider_cfg.data["dashscope-app-api"]["app-type"]
self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"]
#检查配置文件中使用的应用类型是否支持
if (self.app_type not in valid_app_types):
raise DashscopeAPIError(
@@ -42,10 +44,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
)
#初始化Dashscope 参数配置
self.app_id = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["app-id"]
self.api_key = self.ap.provider_cfg.data["dashscope-app-api"]["api-key"]
self.references_quote = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["references_quote"]
self.biz_params = self.ap.provider_cfg.data["dashscope-app-api"]["workflow"]["biz_params"]
self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"]
self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"]
self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"]
def _replace_references(self, text, references_dict):
"""阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料"""
@@ -169,7 +170,6 @@ class DashScopeAPIRunner(runner.RequestRunner):
plain_text, image_ids = await self._preprocess_user_message(query)
biz_params = {}
biz_params.update(self.biz_params)
biz_params.update(query.variables)
#发送对话请求
@@ -220,21 +220,19 @@ class DashScopeAPIRunner(runner.RequestRunner):
content=pending_content,
)
async def run(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行"""
if self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "agent":
if self.app_type == "agent":
async for msg in self._agent_messages(query):
yield msg
elif self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "workflow":
elif self.app_type == "workflow":
async for msg in self._workflow_messages(query):
yield msg
else:
raise DashscopeAPIError(
f"不支持的 Dashscope 应用类型: {self.ap.provider_cfg.data['dashscope-app-api']['app-type']}"
f"不支持的 Dashscope 应用类型: {self.app_type}"
)

View File

@@ -10,7 +10,7 @@ import datetime
import aiohttp
from .. import runner
from ...core import entities as core_entities
from ...core import app, entities as core_entities
from .. import entities as llm_entities
from ...utils import image
@@ -23,24 +23,24 @@ class DifyServiceAPIRunner(runner.RequestRunner):
dify_client: client.AsyncDifyServiceClient
async def initialize(self):
"""初始化"""
def __init__(self, ap: app.Application, pipeline_config: dict):
self.ap = ap
self.pipeline_config = pipeline_config
valid_app_types = ["chat", "agent", "workflow"]
if (
self.ap.provider_cfg.data["dify-service-api"]["app-type"]
self.pipeline_config["ai"]["dify-service-api"]["app-type"]
not in valid_app_types
):
raise errors.DifyAPIError(
f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}"
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
)
api_key = self.ap.provider_cfg.data["dify-service-api"][
self.ap.provider_cfg.data["dify-service-api"]["app-type"]
]["api-key"]
api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"]
self.dify_client = client.AsyncDifyServiceClient(
api_key=api_key,
base_url=self.ap.provider_cfg.data["dify-service-api"]["base-url"],
base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"],
)
def _try_convert_thinking(self, resp_text: str) -> str:
@@ -48,13 +48,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if not resp_text.startswith("<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>"):
return resp_text
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "original":
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original":
return resp_text
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "remove":
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove":
return re.sub(r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>.*?</details>', '', resp_text, flags=re.DOTALL)
if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "plain":
if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain":
pattern = r'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> <summary> Thinking... </summary>(.*?)</details>'
thinking_text = re.search(pattern, resp_text, flags=re.DOTALL)
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
@@ -121,7 +121,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
conversation_id=cov_id,
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
):
self.ap.logger.debug("dify-chat-chunk: " + str(chunk))
@@ -177,7 +177,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
response_mode="streaming",
conversation_id=cov_id,
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
):
self.ap.logger.debug("dify-agent-chunk: " + str(chunk))
@@ -264,7 +264,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
inputs=inputs,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["workflow"]["timeout"],
timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"],
):
self.ap.logger.debug("dify-workflow-chunk: " + str(chunk))
if chunk["event"] in ignored_events:
@@ -301,11 +301,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
msg = llm_entities.Message(
role="assistant",
content=chunk["data"]["outputs"][
self.ap.provider_cfg.data["dify-service-api"]["workflow"][
"output-key"
]
],
content=chunk["data"]["outputs"]["summary"],
)
yield msg
@@ -314,16 +310,16 @@ class DifyServiceAPIRunner(runner.RequestRunner):
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
if self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "chat":
if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat":
async for msg in self._chat_messages(query):
yield msg
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "agent":
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent":
async for msg in self._agent_chat_messages(query):
yield msg
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "workflow":
elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow":
async for msg in self._workflow_messages(query):
yield msg
else:
raise errors.DifyAPIError(
f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}"
f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}"
)

View File

@@ -16,14 +16,12 @@ class LocalAgentRunner(runner.RequestRunner):
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
yield msg
@@ -61,7 +59,7 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)
msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs)
yield msg

View File

@@ -41,7 +41,7 @@ class SessionManager:
self.session_list.append(session)
return session
async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
async def get_conversation(self, query: core_entities.Query, session: core_entities.Session) -> core_entities.Conversation:
"""获取对话或创建对话"""
if not session.conversations:
@@ -51,7 +51,9 @@ class SessionManager:
conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
use_llm_model=await self.ap.model_mgr.get_model_by_uuid(
query.pipeline_config['ai']['local-agent']['model']
),
use_funcs=await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
),