mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: switch Query to langbot-plugin definition
This commit is contained in:
10
main.py
10
main.py
@@ -47,13 +47,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
|||||||
if not args.skip_plugin_deps_check:
|
if not args.skip_plugin_deps_check:
|
||||||
await deps.precheck_plugin_deps()
|
await deps.precheck_plugin_deps()
|
||||||
|
|
||||||
# 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
# # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
||||||
import pydantic.version
|
# import pydantic.version
|
||||||
|
|
||||||
if pydantic.version.VERSION < '2.0':
|
# if pydantic.version.VERSION < '2.0':
|
||||||
import pydantic
|
# import pydantic
|
||||||
|
|
||||||
sys.modules['pydantic.v1'] = pydantic
|
# sys.modules['pydantic.v1'] = pydantic
|
||||||
|
|
||||||
# 检查配置文件
|
# 检查配置文件
|
||||||
|
|
||||||
|
|||||||
@@ -35,15 +35,6 @@ class SystemRouterGroup(group.RouterGroup):
|
|||||||
|
|
||||||
return self.success(data=task.to_dict())
|
return self.success(data=task.to_dict())
|
||||||
|
|
||||||
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
|
||||||
async def _() -> str:
|
|
||||||
json_data = await quart.request.json
|
|
||||||
|
|
||||||
scope = json_data.get('scope')
|
|
||||||
|
|
||||||
await self.ap.reload(scope=scope)
|
|
||||||
return self.success()
|
|
||||||
|
|
||||||
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
if not constants.debug_mode:
|
if not constants.debug_mode:
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app
|
||||||
from . import entities, operator, errors
|
from . import entities, operator, errors
|
||||||
from ..utils import importutil
|
from ..utils import importutil
|
||||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
# 引入所有算子以便注册
|
# 引入所有算子以便注册
|
||||||
from . import operators
|
from . import operators
|
||||||
@@ -90,7 +91,7 @@ class CommandManager:
|
|||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
command_text: str,
|
command_text: str,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
session: provider_session.Session,
|
session: provider_session.Session,
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
"""执行命令"""
|
"""执行命令"""
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
|
|
||||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
from ..core import entities as core_entities
|
|
||||||
from . import errors
|
from . import errors
|
||||||
from ..platform.types import message as platform_message
|
from ..platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class CommandReturn(pydantic.BaseModel):
|
class CommandReturn(pydantic.BaseModel):
|
||||||
@@ -35,7 +35,7 @@ class CommandReturn(pydantic.BaseModel):
|
|||||||
class ExecuteContext(pydantic.BaseModel):
|
class ExecuteContext(pydantic.BaseModel):
|
||||||
"""单次命令执行上下文"""
|
"""单次命令执行上下文"""
|
||||||
|
|
||||||
query: core_entities.Query
|
query: pipeline_query.Query
|
||||||
"""本次消息的请求对象"""
|
"""本次消息的请求对象"""
|
||||||
|
|
||||||
session: provider_session.Session
|
session: provider_session.Session
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from ..platform import botmgr as im_mgr
|
from ..platform import botmgr as im_mgr
|
||||||
@@ -183,59 +182,3 @@ class Application:
|
|||||||
""".strip()
|
""".strip()
|
||||||
for line in tips.split('\n'):
|
for line in tips.split('\n'):
|
||||||
self.logger.info(line)
|
self.logger.info(line)
|
||||||
|
|
||||||
async def reload(
|
|
||||||
self,
|
|
||||||
scope: core_entities.LifecycleControlScope,
|
|
||||||
):
|
|
||||||
match scope:
|
|
||||||
case core_entities.LifecycleControlScope.PLATFORM.value:
|
|
||||||
self.logger.info('执行热重载 scope=' + scope)
|
|
||||||
await self.platform_mgr.shutdown()
|
|
||||||
|
|
||||||
self.platform_mgr = im_mgr.PlatformManager(self)
|
|
||||||
|
|
||||||
await self.platform_mgr.initialize()
|
|
||||||
|
|
||||||
self.task_mgr.create_task(
|
|
||||||
self.platform_mgr.run(),
|
|
||||||
name='platform-manager',
|
|
||||||
scopes=[
|
|
||||||
core_entities.LifecycleControlScope.APPLICATION,
|
|
||||||
core_entities.LifecycleControlScope.PLATFORM,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
case core_entities.LifecycleControlScope.PLUGIN.value:
|
|
||||||
self.logger.info('执行热重载 scope=' + scope)
|
|
||||||
await self.plugin_mgr.destroy_plugins()
|
|
||||||
|
|
||||||
# 删除 sys.module 中所有的 plugins/* 下的模块
|
|
||||||
for mod in list(sys.modules.keys()):
|
|
||||||
if mod.startswith('plugins.'):
|
|
||||||
del sys.modules[mod]
|
|
||||||
|
|
||||||
self.plugin_mgr = plugin_mgr.PluginManager(self)
|
|
||||||
await self.plugin_mgr.initialize()
|
|
||||||
|
|
||||||
await self.plugin_mgr.initialize_plugins()
|
|
||||||
|
|
||||||
await self.plugin_mgr.load_plugins()
|
|
||||||
await self.plugin_mgr.initialize_plugins()
|
|
||||||
case core_entities.LifecycleControlScope.PROVIDER.value:
|
|
||||||
self.logger.info('执行热重载 scope=' + scope)
|
|
||||||
|
|
||||||
await self.tool_mgr.shutdown()
|
|
||||||
|
|
||||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)
|
|
||||||
await llm_model_mgr_inst.initialize()
|
|
||||||
self.model_mgr = llm_model_mgr_inst
|
|
||||||
|
|
||||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(self)
|
|
||||||
await llm_session_mgr_inst.initialize()
|
|
||||||
self.sess_mgr = llm_session_mgr_inst
|
|
||||||
|
|
||||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
|
|
||||||
await llm_tool_mgr_inst.initialize()
|
|
||||||
self.tool_mgr = llm_tool_mgr_inst
|
|
||||||
case _:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import enum
|
import enum
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
|
|
||||||
from ..provider import entities as llm_entities
|
from ..provider import entities as llm_entities
|
||||||
from ..platform import adapter as msadapter
|
from ..platform import adapter as msadapter
|
||||||
@@ -20,23 +20,13 @@ class LifecycleControlScope(enum.Enum):
|
|||||||
PROVIDER = 'provider'
|
PROVIDER = 'provider'
|
||||||
|
|
||||||
|
|
||||||
class LauncherTypes(enum.Enum):
|
|
||||||
"""一个请求的发起者类型"""
|
|
||||||
|
|
||||||
PERSON = 'person'
|
|
||||||
"""私聊"""
|
|
||||||
|
|
||||||
GROUP = 'group'
|
|
||||||
"""群聊"""
|
|
||||||
|
|
||||||
|
|
||||||
class Query(pydantic.BaseModel):
|
class Query(pydantic.BaseModel):
|
||||||
"""一次请求的信息封装"""
|
"""一次请求的信息封装"""
|
||||||
|
|
||||||
query_id: int
|
query_id: int
|
||||||
"""请求ID,添加进请求池时生成"""
|
"""请求ID,添加进请求池时生成"""
|
||||||
|
|
||||||
launcher_type: LauncherTypes
|
launcher_type: provider_session.LauncherTypes
|
||||||
"""会话类型,platform处理阶段设置"""
|
"""会话类型,platform处理阶段设置"""
|
||||||
|
|
||||||
launcher_id: typing.Union[int, str]
|
launcher_id: typing.Union[int, str]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('BanSessionCheckStage')
|
@stage.stage_class('BanSessionCheckStage')
|
||||||
@@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
|||||||
async def initialize(self, pipeline_config: dict):
|
async def initialize(self, pipeline_config: dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
found = False
|
found = False
|
||||||
|
|
||||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||||
|
|||||||
@@ -3,12 +3,11 @@ from __future__ import annotations
|
|||||||
from ...core import app
|
from ...core import app
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
|
||||||
from . import filter as filter_model, entities as filter_entities
|
from . import filter as filter_model, entities as filter_entities
|
||||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||||
from ...platform.types import message as platform_message
|
from ...platform.types import message as platform_message
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
from . import filters
|
from . import filters
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(filters)
|
importutil.import_modules_in_pkg(filters)
|
||||||
@@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage):
|
|||||||
async def _pre_process(
|
async def _pre_process(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
"""请求llm前处理消息
|
"""请求llm前处理消息
|
||||||
只要有一个不通过就不放行,只放行 PASS 的消息
|
只要有一个不通过就不放行,只放行 PASS 的消息
|
||||||
@@ -93,7 +92,7 @@ class ContentFilterStage(stage.PipelineStage):
|
|||||||
async def _post_process(
|
async def _post_process(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
"""请求llm后处理响应
|
"""请求llm后处理响应
|
||||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||||
@@ -123,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage):
|
|||||||
|
|
||||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
if stage_inst_name == 'PreContentFilterStage':
|
if stage_inst_name == 'PreContentFilterStage':
|
||||||
contain_non_text = False
|
contain_non_text = False
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import enum
|
import enum
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
class ResultLevel(enum.Enum):
|
class ResultLevel(enum.Enum):
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from . import entities
|
from . import entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
||||||
"""处理消息
|
"""处理消息
|
||||||
|
|
||||||
分为前后阶段,具体取决于 enable_stages 的值。
|
分为前后阶段,具体取决于 enable_stages 的值。
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ import aiohttp
|
|||||||
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import filter as filter_model
|
from .. import filter as filter_model
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
||||||
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
||||||
@@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
|||||||
) as resp:
|
) as resp:
|
||||||
return (await resp.json())['access_token']
|
return (await resp.json())['access_token']
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import re
|
|||||||
|
|
||||||
from .. import filter as filter_model
|
from .. import filter as filter_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@filter_model.filter_class('ban-word-filter')
|
@filter_model.filter_class('ban-word-filter')
|
||||||
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||||
found = False
|
found = False
|
||||||
|
|
||||||
for word in self.ap.sensitive_meta.data['words']:
|
for word in self.ap.sensitive_meta.data['words']:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import re
|
|||||||
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import filter as filter_model
|
from .. import filter as filter_model
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@filter_model.filter_class('content-ignore')
|
@filter_model.filter_class('content-ignore')
|
||||||
@@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
|
|||||||
entities.EnableStage.PRE,
|
entities.EnableStage.PRE,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||||
if message.startswith(rule):
|
if message.startswith(rule):
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from ..core import app, entities
|
from ..core import app
|
||||||
|
from ..core import entities as core_entities
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class Controller:
|
class Controller:
|
||||||
@@ -22,11 +25,11 @@ class Controller:
|
|||||||
"""事件处理循环"""
|
"""事件处理循环"""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
selected_query: entities.Query = None
|
selected_query: pipeline_query.Query = None
|
||||||
|
|
||||||
# 取请求
|
# 取请求
|
||||||
async with self.ap.query_pool:
|
async with self.ap.query_pool:
|
||||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
queries: list[pipeline_query.Query] = self.ap.query_pool.queries
|
||||||
|
|
||||||
for query in queries:
|
for query in queries:
|
||||||
session = await self.ap.sess_mgr.get_session(query)
|
session = await self.ap.sess_mgr.get_session(query)
|
||||||
@@ -46,7 +49,7 @@ class Controller:
|
|||||||
|
|
||||||
if selected_query:
|
if selected_query:
|
||||||
|
|
||||||
async def _process_query(selected_query: entities.Query):
|
async def _process_query(selected_query: pipeline_query.Query):
|
||||||
async with self.semaphore: # 总并发上限
|
async with self.semaphore: # 总并发上限
|
||||||
# find pipeline
|
# find pipeline
|
||||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||||
@@ -68,8 +71,8 @@ class Controller:
|
|||||||
kind='query',
|
kind='query',
|
||||||
name=f'query-{selected_query.query_id}',
|
name=f'query-{selected_query.query_id}',
|
||||||
scopes=[
|
scopes=[
|
||||||
entities.LifecycleControlScope.APPLICATION,
|
core_entities.LifecycleControlScope.APPLICATION,
|
||||||
entities.LifecycleControlScope.PLATFORM,
|
core_entities.LifecycleControlScope.PLATFORM,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
|||||||
import enum
|
import enum
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
from ..platform.types import message as platform_message
|
from ..platform.types import message as platform_message
|
||||||
|
|
||||||
from ..core import entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class ResultType(enum.Enum):
|
class ResultType(enum.Enum):
|
||||||
@@ -20,7 +20,7 @@ class ResultType(enum.Enum):
|
|||||||
class StageProcessResult(pydantic.BaseModel):
|
class StageProcessResult(pydantic.BaseModel):
|
||||||
result_type: ResultType
|
result_type: ResultType
|
||||||
|
|
||||||
new_query: entities.Query
|
new_query: pipeline_query.Query
|
||||||
|
|
||||||
user_notice: typing.Optional[
|
user_notice: typing.Optional[
|
||||||
typing.Union[
|
typing.Union[
|
||||||
|
|||||||
@@ -5,10 +5,9 @@ import traceback
|
|||||||
|
|
||||||
from . import strategy
|
from . import strategy
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
|
||||||
from ...platform.types import message as platform_message
|
from ...platform.types import message as platform_message
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
from . import strategies
|
from . import strategies
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(strategies)
|
importutil.import_modules_in_pkg(strategies)
|
||||||
@@ -67,7 +66,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
|||||||
|
|
||||||
await self.strategy_impl.initialize()
|
await self.strategy_impl.initialize()
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
# 检查是否包含非 Plain 组件
|
# 检查是否包含非 Plain 组件
|
||||||
contains_non_plain = False
|
contains_non_plain = False
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
|
|
||||||
from .. import strategy as strategy_model
|
from .. import strategy as strategy_model
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....platform.types import message as platform_message
|
|
||||||
|
|
||||||
|
from ....platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
|
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
|
||||||
Forward = platform_message.Forward
|
Forward = platform_message.Forward
|
||||||
@@ -13,7 +13,7 @@ Forward = platform_message.Forward
|
|||||||
|
|
||||||
@strategy_model.strategy_class('forward')
|
@strategy_model.strategy_class('forward')
|
||||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||||
display = ForwardMessageDiaplay(
|
display = ForwardMessageDiaplay(
|
||||||
title='群聊的聊天记录',
|
title='群聊的聊天记录',
|
||||||
brief='[聊天记录]',
|
brief='[聊天记录]',
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import functools
|
|||||||
from ....platform.types import message as platform_message
|
from ....platform.types import message as platform_message
|
||||||
|
|
||||||
from .. import strategy as strategy_model
|
from .. import strategy as strategy_model
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@strategy_model.strategy_class('image')
|
@strategy_model.strategy_class('image')
|
||||||
@@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
|||||||
encoding='utf-8',
|
encoding='utf-8',
|
||||||
)
|
)
|
||||||
|
|
||||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||||
img_path = self.text_to_image(
|
img_path = self.text_to_image(
|
||||||
text_str=message,
|
text_str=message,
|
||||||
save_as='temp/{}.png'.format(int(time.time())),
|
save_as='temp/{}.png'.format(int(time.time())),
|
||||||
@@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
|||||||
text_str: str,
|
text_str: str,
|
||||||
save_as='temp.png',
|
save_as='temp.png',
|
||||||
width=800,
|
width=800,
|
||||||
query: core_entities.Query = None,
|
query: pipeline_query.Query = None,
|
||||||
):
|
):
|
||||||
text_str = text_str.replace('\t', ' ')
|
text_str = text_str.replace('\t', ' ')
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import typing
|
|||||||
|
|
||||||
|
|
||||||
from ...core import app
|
from ...core import app
|
||||||
from ...core import entities as core_entities
|
|
||||||
from ...platform.types import message as platform_message
|
from ...platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
||||||
@@ -49,7 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||||
"""处理长文本
|
"""处理长文本
|
||||||
|
|
||||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
|
||||||
from . import truncator
|
from . import truncator
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
from . import truncators
|
from . import truncators
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(truncators)
|
importutil.import_modules_in_pkg(truncators)
|
||||||
@@ -29,7 +28,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f'未知的截断器: {use_method}')
|
raise ValueError(f'未知的截断器: {use_method}')
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
query = await self.trun.truncate(query)
|
query = await self.trun.truncate(query)
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|||||||
import typing
|
import typing
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
from ...core import entities as core_entities, app
|
from ...core import app
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
preregistered_truncators: list[typing.Type[Truncator]] = []
|
preregistered_truncators: list[typing.Type[Truncator]] = []
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class Truncator(abc.ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||||
"""截断
|
"""截断
|
||||||
|
|
||||||
一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。
|
一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .. import truncator
|
from .. import truncator
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@truncator.truncator_class('round')
|
@truncator.truncator_class('round')
|
||||||
class RoundTruncator(truncator.Truncator):
|
class RoundTruncator(truncator.Truncator):
|
||||||
"""前文回合数阶段器"""
|
"""前文回合数阶段器"""
|
||||||
|
|
||||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||||
"""截断"""
|
"""截断"""
|
||||||
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import traceback
|
|||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from ..core import app, entities
|
from ..core import app
|
||||||
from . import entities as pipeline_entities
|
from . import entities as pipeline_entities
|
||||||
from ..entity.persistence import pipeline as persistence_pipeline
|
from ..entity.persistence import pipeline as persistence_pipeline
|
||||||
from . import stage
|
from . import stage
|
||||||
@@ -13,6 +13,9 @@ from ..platform.types import message as platform_message, events as platform_eve
|
|||||||
from ..plugin import events
|
from ..plugin import events
|
||||||
from ..utils import importutil
|
from ..utils import importutil
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
resprule,
|
resprule,
|
||||||
bansess,
|
bansess,
|
||||||
@@ -75,11 +78,11 @@ class RuntimePipeline:
|
|||||||
self.pipeline_entity = pipeline_entity
|
self.pipeline_entity = pipeline_entity
|
||||||
self.stage_containers = stage_containers
|
self.stage_containers = stage_containers
|
||||||
|
|
||||||
async def run(self, query: entities.Query):
|
async def run(self, query: pipeline_query.Query):
|
||||||
query.pipeline_config = self.pipeline_entity.config
|
query.pipeline_config = self.pipeline_entity.config
|
||||||
await self.process_query(query)
|
await self.process_query(query)
|
||||||
|
|
||||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
|
||||||
"""检查输出"""
|
"""检查输出"""
|
||||||
if result.user_notice:
|
if result.user_notice:
|
||||||
# 处理str类型
|
# 处理str类型
|
||||||
@@ -109,7 +112,7 @@ class RuntimePipeline:
|
|||||||
async def _execute_from_stage(
|
async def _execute_from_stage(
|
||||||
self,
|
self,
|
||||||
stage_index: int,
|
stage_index: int,
|
||||||
query: entities.Query,
|
query: pipeline_query.Query,
|
||||||
):
|
):
|
||||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||||
|
|
||||||
@@ -169,13 +172,13 @@ class RuntimePipeline:
|
|||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
async def process_query(self, query: entities.Query):
|
async def process_query(self, query: pipeline_query.Query):
|
||||||
"""处理请求"""
|
"""处理请求"""
|
||||||
try:
|
try:
|
||||||
# ======== 触发 MessageReceived 事件 ========
|
# ======== 触发 MessageReceived 事件 ========
|
||||||
event_type = (
|
event_type = (
|
||||||
events.PersonMessageReceived
|
events.PersonMessageReceived
|
||||||
if query.launcher_type == entities.LauncherTypes.PERSON
|
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||||
else events.GroupMessageReceived
|
else events.GroupMessageReceived
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ..core import entities
|
|
||||||
from ..platform import adapter as msadapter
|
from ..platform import adapter as msadapter
|
||||||
from ..platform.types import message as platform_message
|
from ..platform.types import message as platform_message
|
||||||
from ..platform.types import events as platform_events
|
from ..platform.types import events as platform_events
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class QueryPool:
|
class QueryPool:
|
||||||
@@ -16,7 +17,7 @@ class QueryPool:
|
|||||||
|
|
||||||
pool_lock: asyncio.Lock
|
pool_lock: asyncio.Lock
|
||||||
|
|
||||||
queries: list[entities.Query]
|
queries: list[pipeline_query.Query]
|
||||||
|
|
||||||
condition: asyncio.Condition
|
condition: asyncio.Condition
|
||||||
|
|
||||||
@@ -29,16 +30,16 @@ class QueryPool:
|
|||||||
async def add_query(
|
async def add_query(
|
||||||
self,
|
self,
|
||||||
bot_uuid: str,
|
bot_uuid: str,
|
||||||
launcher_type: entities.LauncherTypes,
|
launcher_type: provider_session.LauncherTypes,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
sender_id: typing.Union[int, str],
|
sender_id: typing.Union[int, str],
|
||||||
message_event: platform_events.MessageEvent,
|
message_event: platform_events.MessageEvent,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
adapter: msadapter.MessagePlatformAdapter,
|
adapter: msadapter.MessagePlatformAdapter,
|
||||||
pipeline_uuid: typing.Optional[str] = None,
|
pipeline_uuid: typing.Optional[str] = None,
|
||||||
) -> entities.Query:
|
) -> pipeline_query.Query:
|
||||||
async with self.condition:
|
async with self.condition:
|
||||||
query = entities.Query(
|
query = pipeline_query.Query(
|
||||||
bot_uuid=bot_uuid,
|
bot_uuid=bot_uuid,
|
||||||
query_id=self.query_id_counter,
|
query_id=self.query_id_counter,
|
||||||
launcher_type=launcher_type,
|
launcher_type=launcher_type,
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
|
||||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||||
from ...plugin import events
|
from ...plugin import events
|
||||||
from ...platform.types import message as platform_message
|
from ...platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('PreProcessor')
|
@stage.stage_class('PreProcessor')
|
||||||
@@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
|
|
||||||
from ...core import app
|
from ...core import app
|
||||||
from ...core import entities as core_entities
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class MessageHandler(metaclass=abc.ABCMeta):
|
class MessageHandler(metaclass=abc.ABCMeta):
|
||||||
@@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def handle(
|
async def handle(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,15 @@ import traceback
|
|||||||
|
|
||||||
from .. import handler
|
from .. import handler
|
||||||
from ... import entities
|
from ... import entities
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....provider import runner as runner_module
|
from ....provider import runner as runner_module
|
||||||
from ....plugin import events
|
from ....plugin import events
|
||||||
|
|
||||||
from ....platform.types import message as platform_message
|
from ....platform.types import message as platform_message
|
||||||
from ....utils import importutil
|
from ....utils import importutil
|
||||||
from ....provider import runners
|
from ....provider import runners
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(runners)
|
importutil.import_modules_in_pkg(runners)
|
||||||
|
|
||||||
@@ -20,7 +22,7 @@ importutil.import_modules_in_pkg(runners)
|
|||||||
class ChatMessageHandler(handler.MessageHandler):
|
class ChatMessageHandler(handler.MessageHandler):
|
||||||
async def handle(
|
async def handle(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
# 调API
|
# 调API
|
||||||
@@ -29,7 +31,7 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||||||
# 触发插件事件
|
# 触发插件事件
|
||||||
event_class = (
|
event_class = (
|
||||||
events.PersonNormalMessageReceived
|
events.PersonNormalMessageReceived
|
||||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||||
else events.GroupNormalMessageReceived
|
else events.GroupNormalMessageReceived
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ import typing
|
|||||||
|
|
||||||
from .. import handler
|
from .. import handler
|
||||||
from ... import entities
|
from ... import entities
|
||||||
from ....core import entities as core_entities
|
|
||||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||||
from ....plugin import events
|
from ....plugin import events
|
||||||
from ....platform.types import message as platform_message
|
from ....platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class CommandHandler(handler.MessageHandler):
|
class CommandHandler(handler.MessageHandler):
|
||||||
async def handle(
|
async def handle(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
|
|
||||||
@@ -28,7 +29,7 @@ class CommandHandler(handler.MessageHandler):
|
|||||||
|
|
||||||
event_class = (
|
event_class = (
|
||||||
events.PersonCommandSent
|
events.PersonCommandSent
|
||||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||||
else events.GroupCommandSent
|
else events.GroupCommandSent
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from ...core import entities as core_entities
|
|
||||||
from . import handler
|
from . import handler
|
||||||
from .handlers import chat, command
|
from .handlers import chat, command
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import stage
|
from .. import stage
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('MessageProcessor')
|
@stage.stage_class('MessageProcessor')
|
||||||
@@ -30,7 +30,7 @@ class Processor(stage.PipelineStage):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||||
@@ -33,7 +34,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def require_access(
|
async def require_access(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
launcher_type: str,
|
launcher_type: str,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -53,7 +54,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def release_access(
|
async def release_access(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
launcher_type: str,
|
launcher_type: str,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
from .. import algo
|
from .. import algo
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
# 固定窗口算法
|
# 固定窗口算法
|
||||||
@@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
|||||||
|
|
||||||
async def require_access(
|
async def require_access(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
launcher_type: str,
|
launcher_type: str,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
|||||||
|
|
||||||
async def release_access(
|
async def release_access(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
launcher_type: str,
|
launcher_type: str,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ import typing
|
|||||||
|
|
||||||
from .. import entities, stage
|
from .. import entities, stage
|
||||||
from . import algo
|
from . import algo
|
||||||
from ...core import entities as core_entities
|
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
from . import algos
|
from . import algos
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(algos)
|
importutil.import_modules_in_pkg(algos)
|
||||||
@@ -39,7 +40,7 @@ class RateLimit(stage.PipelineStage):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> typing.Union[
|
) -> typing.Union[
|
||||||
entities.StageProcessResult,
|
entities.StageProcessResult,
|
||||||
|
|||||||
@@ -8,14 +8,14 @@ from ...platform.types import events as platform_events
|
|||||||
from ...platform.types import message as platform_message
|
from ...platform.types import message as platform_message
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('SendResponseBackStage')
|
@stage.stage_class('SendResponseBackStage')
|
||||||
class SendResponseBackStage(stage.PipelineStage):
|
class SendResponseBackStage(stage.PipelineStage):
|
||||||
"""发送响应消息"""
|
"""发送响应消息"""
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
|
|
||||||
random_range = (
|
random_range = (
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
|
|
||||||
from ...platform.types import message as platform_message
|
from ...platform.types import message as platform_message
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ from __future__ import annotations
|
|||||||
from . import rule
|
from . import rule
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
from . import rules
|
from . import rules
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(rules)
|
importutil.import_modules_in_pkg(rules)
|
||||||
@@ -32,7 +33,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
|||||||
await rule_inst.initialize()
|
await rule_inst.initialize()
|
||||||
self.rule_matchers.append(rule_inst)
|
self.rule_matchers.append(rule_inst)
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
if query.launcher_type.value != 'group': # 只处理群消息
|
if query.launcher_type.value != 'group': # 只处理群消息
|
||||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from . import entities
|
from . import entities
|
||||||
|
|
||||||
from ...platform.types import message as platform_message
|
from ...platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
||||||
@@ -39,7 +40,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
"""判断消息是否匹配规则"""
|
"""判断消息是否匹配规则"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from .. import rule as rule_model
|
from .. import rule as rule_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....platform.types import message as platform_message
|
from ....platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@rule_model.rule_class('at-bot')
|
@rule_model.rule_class('at-bot')
|
||||||
@@ -14,7 +14,7 @@ class AtBotRule(rule_model.GroupRespondRule):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from .. import rule as rule_model
|
from .. import rule as rule_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....platform.types import message as platform_message
|
from ....platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@rule_model.rule_class('prefix')
|
@rule_model.rule_class('prefix')
|
||||||
@@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
prefixes = rule_dict['prefix']
|
prefixes = rule_dict['prefix']
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ import random
|
|||||||
|
|
||||||
from .. import rule as rule_model
|
from .. import rule as rule_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....platform.types import message as platform_message
|
from ....platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@rule_model.rule_class('random')
|
@rule_model.rule_class('random')
|
||||||
@@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
random_rate = rule_dict['random']
|
random_rate = rule_dict['random']
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ import re
|
|||||||
|
|
||||||
from .. import rule as rule_model
|
from .. import rule as rule_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....platform.types import message as platform_message
|
from ....platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@rule_model.rule_class('regexp')
|
@rule_model.rule_class('regexp')
|
||||||
@@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
regexps = rule_dict['regexp']
|
regexps = rule_dict['regexp']
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app
|
||||||
from . import entities
|
from . import entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregistered_stages: dict[str, type[PipelineStage]] = {}
|
preregistered_stages: dict[str, type[PipelineStage]] = {}
|
||||||
@@ -33,7 +34,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> typing.Union[
|
) -> typing.Union[
|
||||||
entities.StageProcessResult,
|
entities.StageProcessResult,
|
||||||
|
|||||||
@@ -2,12 +2,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
|
||||||
from ...core import entities as core_entities
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import stage
|
from .. import stage
|
||||||
from ...plugin import events
|
from ...plugin import events
|
||||||
from ...platform.types import message as platform_message
|
from ...platform.types import message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('ResponseWrapper')
|
@stage.stage_class('ResponseWrapper')
|
||||||
@@ -25,7 +24,7 @@ class ResponseWrapper(stage.PipelineStage):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
|
|||||||
@@ -3,15 +3,14 @@ from __future__ import annotations
|
|||||||
# MessageSource的适配器
|
# MessageSource的适配器
|
||||||
import typing
|
import typing
|
||||||
import abc
|
import abc
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
from ..core import app
|
|
||||||
from .types import message as platform_message
|
from .types import message as platform_message
|
||||||
from .types import events as platform_events
|
from .types import events as platform_events
|
||||||
from .logger import EventLogger
|
from .logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
class MessagePlatformAdapter(pydantic.BaseModel, metaclass=abc.ABCMeta):
|
||||||
"""消息平台适配器基类"""
|
"""消息平台适配器基类"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
@@ -21,11 +20,9 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
ap: app.Application
|
logger: EventLogger = pydantic.Field(exclude=True)
|
||||||
|
|
||||||
logger: EventLogger
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
|
||||||
"""初始化适配器
|
"""初始化适配器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -33,7 +30,6 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
|||||||
ap (app.Application): 应用上下文
|
ap (app.Application): 应用上下文
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ from ..entity.errors import platform as platform_errors
|
|||||||
|
|
||||||
from .logger import EventLogger
|
from .logger import EventLogger
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
|
||||||
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
|
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
|
||||||
from . import types as mirai
|
from . import types as mirai
|
||||||
|
|
||||||
@@ -73,7 +75,7 @@ class RuntimeBot:
|
|||||||
|
|
||||||
await self.ap.query_pool.add_query(
|
await self.ap.query_pool.add_query(
|
||||||
bot_uuid=self.bot_entity.uuid,
|
bot_uuid=self.bot_entity.uuid,
|
||||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||||
launcher_id=event.sender.id,
|
launcher_id=event.sender.id,
|
||||||
sender_id=event.sender.id,
|
sender_id=event.sender.id,
|
||||||
message_event=event,
|
message_event=event,
|
||||||
@@ -98,7 +100,7 @@ class RuntimeBot:
|
|||||||
|
|
||||||
await self.ap.query_pool.add_query(
|
await self.ap.query_pool.add_query(
|
||||||
bot_uuid=self.bot_entity.uuid,
|
bot_uuid=self.bot_entity.uuid,
|
||||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||||
launcher_id=event.group.id,
|
launcher_id=event.group.id,
|
||||||
sender_id=event.sender.id,
|
sender_id=event.sender.id,
|
||||||
message_event=event,
|
message_event=event,
|
||||||
@@ -172,9 +174,9 @@ class PlatformManager:
|
|||||||
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
|
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
|
||||||
webchat_adapter_inst = webchat_adapter_class(
|
webchat_adapter_inst = webchat_adapter_class(
|
||||||
{},
|
{},
|
||||||
self.ap,
|
|
||||||
webchat_logger,
|
webchat_logger,
|
||||||
)
|
)
|
||||||
|
webchat_adapter_inst.ap = self.ap
|
||||||
|
|
||||||
self.webchat_proxy_bot = RuntimeBot(
|
self.webchat_proxy_bot = RuntimeBot(
|
||||||
ap=self.ap,
|
ap=self.ap,
|
||||||
@@ -231,7 +233,6 @@ class PlatformManager:
|
|||||||
|
|
||||||
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
||||||
bot_entity.adapter_config,
|
bot_entity.adapter_config,
|
||||||
self.ap,
|
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import datetime
|
|||||||
import aiocqhttp
|
import aiocqhttp
|
||||||
|
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ...core import app
|
|
||||||
from ..types import message as platform_message
|
from ..types import message as platform_message
|
||||||
from ..types import events as platform_events
|
from ..types import events as platform_events
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
@@ -273,11 +272,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
|
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
@@ -287,7 +284,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
self.config['shutdown_trigger'] = shutdown_trigger_placeholder
|
self.config['shutdown_trigger'] = shutdown_trigger_placeholder
|
||||||
|
|
||||||
self.ap = ap
|
|
||||||
self.on_websocket_connection_event_cache = []
|
self.on_websocket_connection_event_cache = []
|
||||||
|
|
||||||
if 'access-token' in config:
|
if 'access-token' in config:
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from libs.dingtalk_api.dingtalkevent import DingTalkEvent
|
|||||||
from pkg.platform.types import message as platform_message
|
from pkg.platform.types import message as platform_message
|
||||||
from pkg.platform.adapter import MessagePlatformAdapter
|
from pkg.platform.adapter import MessagePlatformAdapter
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ...core import app
|
|
||||||
from ..types import events as platform_events
|
from ..types import events as platform_events
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
from libs.dingtalk_api.api import DingTalkClient
|
from libs.dingtalk_api.api import DingTalkClient
|
||||||
@@ -94,15 +93,13 @@ class DingTalkEventConverter(adapter.EventConverter):
|
|||||||
|
|
||||||
class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||||
bot: DingTalkClient
|
bot: DingTalkClient
|
||||||
ap: app.Application
|
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
|
message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
|
||||||
event_converter: DingTalkEventConverter = DingTalkEventConverter()
|
event_converter: DingTalkEventConverter = DingTalkEventConverter()
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
required_keys = [
|
required_keys = [
|
||||||
'client_id',
|
'client_id',
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import datetime
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ...core import app
|
|
||||||
from ..types import message as platform_message
|
from ..types import message as platform_message
|
||||||
from ..types import events as platform_events
|
from ..types import events as platform_events
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
@@ -161,8 +160,6 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
message_converter: DiscordMessageConverter = DiscordMessageConverter()
|
message_converter: DiscordMessageConverter = DiscordMessageConverter()
|
||||||
event_converter: DiscordEventConverter = DiscordEventConverter()
|
event_converter: DiscordEventConverter = DiscordEventConverter()
|
||||||
|
|
||||||
@@ -171,9 +168,8 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
|||||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||||
] = {}
|
] = {}
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
self.bot_account_id = self.config['client_id']
|
self.bot_account_id = self.config['client_id']
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import quart
|
|||||||
from lark_oapi.api.im.v1 import *
|
from lark_oapi.api.im.v1 import *
|
||||||
|
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ...core import app
|
|
||||||
from ..types import message as platform_message
|
from ..types import message as platform_message
|
||||||
from ..types import events as platform_events
|
from ..types import events as platform_events
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
@@ -337,11 +336,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
config: dict
|
config: dict
|
||||||
quart_app: quart.Quart
|
quart_app: quart.Quart
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.quart_app = quart.Quart(__name__)
|
self.quart_app = quart.Quart(__name__)
|
||||||
self.listeners = {}
|
self.listeners = {}
|
||||||
@@ -351,8 +348,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
|||||||
try:
|
try:
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
|
|
||||||
self.ap.logger.debug(f'Lark callback event: {data}')
|
|
||||||
|
|
||||||
if 'encrypt' in data:
|
if 'encrypt' in data:
|
||||||
cipher = AESCipher(self.config['encrypt-key'])
|
cipher = AESCipher(self.config['encrypt-key'])
|
||||||
data = cipher.decrypt_string(data['encrypt'])
|
data = cipher.decrypt_string(data['encrypt'])
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 25 KiB After Width: | Height: | Size: 25 KiB |
@@ -11,16 +11,16 @@ import threading
|
|||||||
import quart
|
import quart
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from .. import adapter
|
from ... import adapter
|
||||||
from ...core import app
|
from ....core import app
|
||||||
from ..types import message as platform_message
|
from ...types import message as platform_message
|
||||||
from ..types import events as platform_events
|
from ...types import events as platform_events
|
||||||
from ..types import entities as platform_entities
|
from ...types import entities as platform_entities
|
||||||
from ...utils import image
|
from ....utils import image
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from ..logger import EventLogger
|
from ...logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
class GewechatMessageConverter(adapter.MessageConverter):
|
class GewechatMessageConverter(adapter.MessageConverter):
|
||||||
@@ -491,7 +491,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
|||||||
async def gewechat_callback():
|
async def gewechat_callback():
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
# print(json.dumps(data, indent=4, ensure_ascii=False))
|
# print(json.dumps(data, indent=4, ensure_ascii=False))
|
||||||
self.ap.logger.debug(f'Gewechat callback event: {data}')
|
await self.logger.debug(f'Gewechat callback event: {data}')
|
||||||
|
|
||||||
if 'data' in data:
|
if 'data' in data:
|
||||||
data['Data'] = data['data']
|
data['Data'] = data['data']
|
||||||
@@ -601,7 +601,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
|||||||
if handler := handler_map.get(msg['type']):
|
if handler := handler_map.get(msg['type']):
|
||||||
handler(msg)
|
handler(msg)
|
||||||
else:
|
else:
|
||||||
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
await self.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||||
@@ -656,9 +656,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
self.config['app_id'] = app_id
|
self.config['app_id'] = app_id
|
||||||
|
|
||||||
self.ap.logger.info(f'Gewechat 登录成功,app_id: {app_id}')
|
print(f'Gewechat 登录成功,app_id: {app_id}')
|
||||||
|
|
||||||
self.ap.platform_mgr.write_back_config('gewechat', self, self.config)
|
|
||||||
|
|
||||||
# 获取 nickname
|
# 获取 nickname
|
||||||
profile = self.bot.get_profile(self.config['app_id'])
|
profile = self.bot.get_profile(self.config['app_id'])
|
||||||
|
Before Width: | Height: | Size: 274 KiB After Width: | Height: | Size: 274 KiB |
@@ -9,12 +9,12 @@ import traceback
|
|||||||
import nakuru
|
import nakuru
|
||||||
import nakuru.entities.components as nkc
|
import nakuru.entities.components as nkc
|
||||||
|
|
||||||
from .. import adapter as adapter_model
|
from ... import adapter as adapter_model
|
||||||
from ...pipeline.longtext.strategies import forward
|
from ....pipeline.longtext.strategies import forward
|
||||||
from ...platform.types import message as platform_message
|
from ...types import message as platform_message
|
||||||
from ...platform.types import entities as platform_entities
|
from ...types import entities as platform_entities
|
||||||
from ...platform.types import events as platform_events
|
from ...types import events as platform_events
|
||||||
from ..logger import EventLogger
|
from ...logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||||
@@ -262,7 +262,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
|
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
|
||||||
|
|
||||||
# 包装函数
|
# 包装函数
|
||||||
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls):
|
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore
|
||||||
await callback(self.event_converter.target2yiri(source), self)
|
await callback(self.event_converter.target2yiri(source), self)
|
||||||
|
|
||||||
# 将包装函数和原函数的对应关系存入列表
|
# 将包装函数和原函数的对应关系存入列表
|
||||||
@@ -322,7 +322,6 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
|
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
|
||||||
await self.bot._run()
|
await self.bot._run()
|
||||||
self.ap.logger.info('运行 Nakuru 适配器')
|
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
@@ -10,14 +10,14 @@ import botpy
|
|||||||
import botpy.message as botpy_message
|
import botpy.message as botpy_message
|
||||||
import botpy.types.message as botpy_message_type
|
import botpy.types.message as botpy_message_type
|
||||||
|
|
||||||
from .. import adapter as adapter_model
|
from ... import adapter as adapter_model
|
||||||
from ...pipeline.longtext.strategies import forward
|
from ....pipeline.longtext.strategies import forward
|
||||||
from ...core import app
|
from ....core import app
|
||||||
from ...config import manager as cfg_mgr
|
from ....config import manager as cfg_mgr
|
||||||
from ...platform.types import entities as platform_entities
|
from ...types import entities as platform_entities
|
||||||
from ...platform.types import events as platform_events
|
from ...types import events as platform_events
|
||||||
from ...platform.types import message as platform_message
|
from ...types import message as platform_message
|
||||||
from ..logger import EventLogger
|
from ...logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
class OfficialGroupMessage(platform_events.GroupMessage):
|
class OfficialGroupMessage(platform_events.GroupMessage):
|
||||||
@@ -519,7 +519,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
|
|
||||||
self.cfg['ret_coro'] = True
|
self.cfg['ret_coro'] = True
|
||||||
|
|
||||||
self.ap.logger.info('运行 QQ 官方适配器')
|
await self.logger.info('运行 QQ 官方适配器')
|
||||||
await (await self.bot.start(**self.cfg))
|
await (await self.bot.start(**self.cfg))
|
||||||
|
|
||||||
async def kill(self) -> bool:
|
async def kill(self) -> bool:
|
||||||
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.7 KiB |
@@ -10,7 +10,6 @@ from libs.official_account_api.oaevent import OAEvent
|
|||||||
from libs.official_account_api.api import OAClient
|
from libs.official_account_api.api import OAClient
|
||||||
from libs.official_account_api.api import OAClientForLongerResponse
|
from libs.official_account_api.api import OAClientForLongerResponse
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ...core import app
|
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
from ...command.errors import ParamNotEnoughError
|
from ...command.errors import ParamNotEnoughError
|
||||||
from ..logger import EventLogger
|
from ..logger import EventLogger
|
||||||
@@ -58,15 +57,13 @@ class OAEventConverter(adapter.EventConverter):
|
|||||||
|
|
||||||
class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||||
bot: OAClient | OAClientForLongerResponse
|
bot: OAClient | OAClientForLongerResponse
|
||||||
ap: app.Application
|
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: OAMessageConverter = OAMessageConverter()
|
message_converter: OAMessageConverter = OAMessageConverter()
|
||||||
event_converter: OAEventConverter = OAEventConverter()
|
event_converter: OAEventConverter = OAEventConverter()
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
required_keys = [
|
required_keys = [
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import datetime
|
|||||||
from pkg.platform.adapter import MessagePlatformAdapter
|
from pkg.platform.adapter import MessagePlatformAdapter
|
||||||
from pkg.platform.types import events as platform_events, message as platform_message
|
from pkg.platform.types import events as platform_events, message as platform_message
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ...core import app
|
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
from ...command.errors import ParamNotEnoughError
|
from ...command.errors import ParamNotEnoughError
|
||||||
from libs.qq_official_api.api import QQOfficialClient
|
from libs.qq_official_api.api import QQOfficialClient
|
||||||
@@ -134,15 +133,13 @@ class QQOfficialEventConverter(adapter.EventConverter):
|
|||||||
|
|
||||||
class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||||
bot: QQOfficialClient
|
bot: QQOfficialClient
|
||||||
ap: app.Application
|
|
||||||
config: dict
|
config: dict
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
|
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
|
||||||
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
required_keys = [
|
required_keys = [
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from libs.slack_api.api import SlackClient
|
|||||||
from pkg.platform.adapter import MessagePlatformAdapter
|
from pkg.platform.adapter import MessagePlatformAdapter
|
||||||
from pkg.platform.types import events as platform_events, message as platform_message
|
from pkg.platform.types import events as platform_events, message as platform_message
|
||||||
from libs.slack_api.slackevent import SlackEvent
|
from libs.slack_api.slackevent import SlackEvent
|
||||||
from pkg.core import app
|
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
from ...command.errors import ParamNotEnoughError
|
from ...command.errors import ParamNotEnoughError
|
||||||
@@ -86,15 +85,13 @@ class SlackEventConverter(adapter.EventConverter):
|
|||||||
|
|
||||||
class SlackAdapter(adapter.MessagePlatformAdapter):
|
class SlackAdapter(adapter.MessagePlatformAdapter):
|
||||||
bot: SlackClient
|
bot: SlackClient
|
||||||
ap: app.Application
|
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: SlackMessageConverter = SlackMessageConverter()
|
message_converter: SlackMessageConverter = SlackMessageConverter()
|
||||||
event_converter: SlackEventConverter = SlackEventConverter()
|
event_converter: SlackEventConverter = SlackEventConverter()
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
required_keys = [
|
required_keys = [
|
||||||
'bot_token',
|
'bot_token',
|
||||||
|
|||||||
@@ -10,10 +10,7 @@ import traceback
|
|||||||
import base64
|
import base64
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from lark_oapi.api.im.v1 import *
|
|
||||||
|
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ...core import app
|
|
||||||
from ..types import message as platform_message
|
from ..types import message as platform_message
|
||||||
from ..types import events as platform_events
|
from ..types import events as platform_events
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
@@ -141,16 +138,14 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
|||||||
event_converter: TelegramEventConverter = TelegramEventConverter()
|
event_converter: TelegramEventConverter = TelegramEventConverter()
|
||||||
|
|
||||||
config: dict
|
config: dict
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
listeners: typing.Dict[
|
listeners: typing.Dict[
|
||||||
typing.Type[platform_events.Event],
|
typing.Type[platform_events.Event],
|
||||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||||
] = {}
|
] = {}
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
|
|||||||
@@ -44,13 +44,14 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
|||||||
webchat_person_session: WebChatSession
|
webchat_person_session: WebChatSession
|
||||||
webchat_group_session: WebChatSession
|
webchat_group_session: WebChatSession
|
||||||
|
|
||||||
|
ap: app.Application # set by bot manager
|
||||||
|
|
||||||
listeners: typing.Dict[
|
listeners: typing.Dict[
|
||||||
typing.Type[platform_events.Event],
|
typing.Type[platform_events.Event],
|
||||||
typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None],
|
typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None],
|
||||||
] = {}
|
] = {}
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|||||||
@@ -488,6 +488,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
|
logger: EventLogger
|
||||||
|
|
||||||
message_converter: WeChatPadMessageConverter
|
message_converter: WeChatPadMessageConverter
|
||||||
event_converter: WeChatPadEventConverter
|
event_converter: WeChatPadEventConverter
|
||||||
|
|
||||||
@@ -507,8 +509,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
async def ws_message(self, data):
|
async def ws_message(self, data):
|
||||||
"""处理接收到的消息"""
|
"""处理接收到的消息"""
|
||||||
# self.ap.logger.debug(f"Gewechat callback event: {data}")
|
|
||||||
# print(data)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
|
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
|
||||||
@@ -571,9 +571,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
if handler := handler_map.get(msg['type']):
|
if handler := handler_map.get(msg['type']):
|
||||||
handler(msg)
|
handler(msg)
|
||||||
# self.ap.logger.warning(f"未处理的消息类型: {ret}")
|
|
||||||
else:
|
else:
|
||||||
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
print(f'未处理的消息类型: {msg["type"]}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||||
@@ -615,7 +614,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
|||||||
if self.config['token']:
|
if self.config['token']:
|
||||||
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
|
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
|
||||||
data = self.bot.get_login_status()
|
data = self.bot.get_login_status()
|
||||||
self.ap.logger.info(data)
|
|
||||||
if data['Code'] == 300 and data['Text'] == '你已退出微信':
|
if data['Code'] == 300 and data['Text'] == '你已退出微信':
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
|
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
|
||||||
@@ -635,7 +633,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
|||||||
self.config['token'] = response.json()['Data'][0]
|
self.config['token'] = response.json()['Data'][0]
|
||||||
|
|
||||||
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
|
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
|
||||||
self.ap.logger.info(self.config['token'])
|
await self.logger.info(self.config['token'])
|
||||||
thread_1 = threading.Event()
|
thread_1 = threading.Event()
|
||||||
|
|
||||||
def wechat_login_process():
|
def wechat_login_process():
|
||||||
@@ -643,10 +641,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
|||||||
# login_data =self.bot.get_login_qr()
|
# login_data =self.bot.get_login_qr()
|
||||||
|
|
||||||
# url = login_data['Data']["QrCodeUrl"]
|
# url = login_data['Data']["QrCodeUrl"]
|
||||||
# self.ap.logger.info(login_data)
|
|
||||||
|
|
||||||
profile = self.bot.get_profile()
|
profile = self.bot.get_profile()
|
||||||
self.ap.logger.info(profile)
|
self.logger.info(profile)
|
||||||
|
|
||||||
self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
|
self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
|
||||||
self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
|
self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
|
||||||
@@ -658,27 +655,26 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
|||||||
def connect_websocket_sync() -> None:
|
def connect_websocket_sync() -> None:
|
||||||
thread_1.wait()
|
thread_1.wait()
|
||||||
uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
|
uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
|
||||||
self.ap.logger.info(f'Connecting to WebSocket: {uri}')
|
print(f'Connecting to WebSocket: {uri}')
|
||||||
|
|
||||||
def on_message(ws, message):
|
def on_message(ws, message):
|
||||||
try:
|
try:
|
||||||
data = json.loads(message)
|
data = json.loads(message)
|
||||||
self.ap.logger.debug(f'Received message: {data}')
|
|
||||||
# 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法
|
# 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法
|
||||||
asyncio.run(self.ws_message(data))
|
asyncio.run(self.ws_message(data))
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
self.ap.logger.error(f'Non-JSON message: {message[:100]}...')
|
print(f'Non-JSON message: {message[:100]}...')
|
||||||
|
|
||||||
def on_error(ws, error):
|
def on_error(ws, error):
|
||||||
self.ap.logger.error(f'WebSocket error: {str(error)[:200]}')
|
print(f'WebSocket error: {str(error)[:200]}')
|
||||||
|
|
||||||
def on_close(ws, close_status_code, close_msg):
|
def on_close(ws, close_status_code, close_msg):
|
||||||
self.ap.logger.info('WebSocket closed, reconnecting...')
|
print('WebSocket closed, reconnecting...')
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
connect_websocket_sync() # 自动重连
|
connect_websocket_sync() # 自动重连
|
||||||
|
|
||||||
def on_open(ws):
|
def on_open(ws):
|
||||||
self.ap.logger.info('WebSocket connected successfully!')
|
print('WebSocket connected successfully!')
|
||||||
|
|
||||||
ws = websocket.WebSocketApp(
|
ws = websocket.WebSocketApp(
|
||||||
uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
|
uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
|
||||||
@@ -689,10 +685,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
|
|||||||
# connect_websocket_sync()
|
# connect_websocket_sync()
|
||||||
|
|
||||||
# 这行代码会在WebSocket连接断开后才会执行
|
# 这行代码会在WebSocket连接断开后才会执行
|
||||||
# self.ap.logger.info("WebSocket client thread started")
|
|
||||||
thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
|
thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
self.ap.logger.info('WebSocket client thread started')
|
self.logger.info('WebSocket client thread started')
|
||||||
|
|
||||||
async def kill(self) -> bool:
|
async def kill(self) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from pkg.platform.adapter import MessagePlatformAdapter
|
|||||||
from pkg.platform.types import events as platform_events, message as platform_message
|
from pkg.platform.types import events as platform_events, message as platform_message
|
||||||
from libs.wecom_api.wecomevent import WecomEvent
|
from libs.wecom_api.wecomevent import WecomEvent
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ...core import app
|
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
from ...command.errors import ParamNotEnoughError
|
from ...command.errors import ParamNotEnoughError
|
||||||
from ...utils import image
|
from ...utils import image
|
||||||
@@ -129,15 +128,13 @@ class WecomEventConverter:
|
|||||||
|
|
||||||
class WecomAdapter(adapter.MessagePlatformAdapter):
|
class WecomAdapter(adapter.MessagePlatformAdapter):
|
||||||
bot: WecomClient
|
bot: WecomClient
|
||||||
ap: app.Application
|
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: WecomMessageConverter = WecomMessageConverter()
|
message_converter: WecomMessageConverter = WecomMessageConverter()
|
||||||
event_converter: WecomEventConverter = WecomEventConverter()
|
event_converter: WecomEventConverter = WecomEventConverter()
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
required_keys = [
|
required_keys = [
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from libs.wecom_customer_service_api.api import WecomCSClient
|
|||||||
from pkg.platform.adapter import MessagePlatformAdapter
|
from pkg.platform.adapter import MessagePlatformAdapter
|
||||||
from pkg.platform.types import events as platform_events, message as platform_message
|
from pkg.platform.types import events as platform_events, message as platform_message
|
||||||
from libs.wecom_customer_service_api.wecomcsevent import WecomCSEvent
|
from libs.wecom_customer_service_api.wecomcsevent import WecomCSEvent
|
||||||
from pkg.core import app
|
|
||||||
from .. import adapter
|
from .. import adapter
|
||||||
from ..types import entities as platform_entities
|
from ..types import entities as platform_entities
|
||||||
from ...command.errors import ParamNotEnoughError
|
from ...command.errors import ParamNotEnoughError
|
||||||
@@ -119,15 +118,13 @@ class WecomEventConverter:
|
|||||||
|
|
||||||
class WecomCSAdapter(adapter.MessagePlatformAdapter):
|
class WecomCSAdapter(adapter.MessagePlatformAdapter):
|
||||||
bot: WecomCSClient
|
bot: WecomCSClient
|
||||||
ap: app.Application
|
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: WecomMessageConverter = WecomMessageConverter()
|
message_converter: WecomMessageConverter = WecomMessageConverter()
|
||||||
event_converter: WecomEventConverter = WecomEventConverter()
|
event_converter: WecomEventConverter = WecomEventConverter()
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
required_keys = [
|
required_keys = [
|
||||||
|
|||||||
@@ -4,16 +4,16 @@ import typing
|
|||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic.v1 as pydantic
|
||||||
|
|
||||||
from ..core import entities as core_entities
|
|
||||||
from ..provider import entities as llm_entities
|
from ..provider import entities as llm_entities
|
||||||
from ..platform.types import message as platform_message
|
from ..platform.types import message as platform_message
|
||||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class BaseEventModel(pydantic.BaseModel):
|
class BaseEventModel(pydantic.BaseModel):
|
||||||
"""事件模型基类"""
|
"""事件模型基类"""
|
||||||
|
|
||||||
query: typing.Union[core_entities.Query, None]
|
query: typing.Union[pipeline_query.Query, None]
|
||||||
"""此次请求的query对象,非请求过程的事件时为None"""
|
"""此次请求的query对象,非请求过程的事件时为None"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ import importlib
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from .. import loader, events, context, models
|
from .. import loader, events, context, models
|
||||||
from ...core import entities as core_entities
|
|
||||||
from langbot_plugin.api.entities.builtin.resource import tool as resource_tool
|
from langbot_plugin.api.entities.builtin.resource import tool as resource_tool
|
||||||
from ...utils import funcschema
|
from ...utils import funcschema
|
||||||
from ...discover import engine as discover_engine
|
from ...discover import engine as discover_engine
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class PluginLoader(loader.PluginLoader):
|
class PluginLoader(loader.PluginLoader):
|
||||||
@@ -98,7 +98,7 @@ class PluginLoader(loader.PluginLoader):
|
|||||||
function_schema = funcschema.get_func_schema(func)
|
function_schema = funcschema.get_func_schema(func)
|
||||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||||
|
|
||||||
async def handler(plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs):
|
async def handler(plugin: context.BasePlugin, query: pipeline_query.Query, *args, **kwargs):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
llm_function = resource_tool.LLMTool(
|
llm_function = resource_tool.LLMTool(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
|
|
||||||
from pkg.provider import entities
|
from pkg.provider import entities
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
|
|
||||||
from . import requester
|
from . import requester
|
||||||
from . import token
|
from . import token
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import abc
|
|||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app
|
from ...core import app
|
||||||
from ...core import entities as core_entities
|
|
||||||
from .. import entities as llm_entities
|
from .. import entities as llm_entities
|
||||||
from ...entity.persistence import model as persistence_model
|
from ...entity.persistence import model as persistence_model
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
from . import token
|
from . import token
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class RuntimeLLMModel:
|
class RuntimeLLMModel:
|
||||||
@@ -56,7 +56,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def invoke_llm(
|
async def invoke_llm(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
model: RuntimeLLMModel,
|
model: RuntimeLLMModel,
|
||||||
messages: typing.List[llm_entities.Message],
|
messages: typing.List[llm_entities.Message],
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ import httpx
|
|||||||
|
|
||||||
from .. import errors, requester
|
from .. import errors, requester
|
||||||
|
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ... import entities as llm_entities
|
from ... import entities as llm_entities
|
||||||
from ....utils import image
|
from ....utils import image
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class AnthropicMessages(requester.LLMAPIRequester):
|
class AnthropicMessages(requester.LLMAPIRequester):
|
||||||
@@ -48,7 +48,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
|
|||||||
|
|
||||||
async def invoke_llm(
|
async def invoke_llm(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
model: requester.RuntimeLLMModel,
|
model: requester.RuntimeLLMModel,
|
||||||
messages: typing.List[llm_entities.Message],
|
messages: typing.List[llm_entities.Message],
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ import openai.types.chat.chat_completion as chat_completion
|
|||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from .. import errors, requester
|
from .. import errors, requester
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ... import entities as llm_entities
|
from ... import entities as llm_entities
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChatCompletions(requester.LLMAPIRequester):
|
class OpenAIChatCompletions(requester.LLMAPIRequester):
|
||||||
@@ -60,7 +60,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
|||||||
|
|
||||||
async def _closure(
|
async def _closure(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
req_messages: list[dict],
|
req_messages: list[dict],
|
||||||
use_model: requester.RuntimeLLMModel,
|
use_model: requester.RuntimeLLMModel,
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
use_funcs: list[resource_tool.LLMTool] = None,
|
||||||
@@ -101,7 +101,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
|
|||||||
|
|
||||||
async def invoke_llm(
|
async def invoke_llm(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
model: requester.RuntimeLLMModel,
|
model: requester.RuntimeLLMModel,
|
||||||
messages: typing.List[llm_entities.Message],
|
messages: typing.List[llm_entities.Message],
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ import typing
|
|||||||
|
|
||||||
from . import chatcmpl
|
from . import chatcmpl
|
||||||
from .. import errors, requester
|
from .. import errors, requester
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ... import entities as llm_entities
|
from ... import entities as llm_entities
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||||
@@ -19,7 +19,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|||||||
|
|
||||||
async def _closure(
|
async def _closure(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
req_messages: list[dict],
|
req_messages: list[dict],
|
||||||
use_model: requester.RuntimeLLMModel,
|
use_model: requester.RuntimeLLMModel,
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
use_funcs: list[resource_tool.LLMTool] = None,
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import typing
|
|||||||
|
|
||||||
from . import chatcmpl
|
from . import chatcmpl
|
||||||
from .. import requester
|
from .. import requester
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ... import entities as llm_entities
|
from ... import entities as llm_entities
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||||
@@ -20,7 +20,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|||||||
|
|
||||||
async def _closure(
|
async def _closure(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
req_messages: list[dict],
|
req_messages: list[dict],
|
||||||
use_model: requester.RuntimeLLMModel,
|
use_model: requester.RuntimeLLMModel,
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
use_funcs: list[resource_tool.LLMTool] = None,
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ import openai.types.chat.chat_completion_message_tool_call as chat_completion_me
|
|||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from .. import entities, errors, requester
|
from .. import entities, errors, requester
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ... import entities as llm_entities
|
from ... import entities as llm_entities
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
||||||
@@ -125,7 +125,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
|||||||
|
|
||||||
async def _closure(
|
async def _closure(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
req_messages: list[dict],
|
req_messages: list[dict],
|
||||||
use_model: requester.RuntimeLLMModel,
|
use_model: requester.RuntimeLLMModel,
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
use_funcs: list[resource_tool.LLMTool] = None,
|
||||||
@@ -166,7 +166,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
|
|||||||
|
|
||||||
async def invoke_llm(
|
async def invoke_llm(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
model: entities.LLMModelInfo,
|
model: entities.LLMModelInfo,
|
||||||
messages: typing.List[llm_entities.Message],
|
messages: typing.List[llm_entities.Message],
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import typing
|
|||||||
|
|
||||||
from . import chatcmpl
|
from . import chatcmpl
|
||||||
from .. import requester
|
from .. import requester
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ... import entities as llm_entities
|
from ... import entities as llm_entities
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||||
@@ -20,7 +20,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
|||||||
|
|
||||||
async def _closure(
|
async def _closure(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
req_messages: list[dict],
|
req_messages: list[dict],
|
||||||
use_model: requester.RuntimeLLMModel,
|
use_model: requester.RuntimeLLMModel,
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
use_funcs: list[resource_tool.LLMTool] = None,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import ollama
|
|||||||
from .. import errors, requester
|
from .. import errors, requester
|
||||||
from ... import entities as llm_entities
|
from ... import entities as llm_entities
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
REQUESTER_NAME: str = 'ollama-chat'
|
REQUESTER_NAME: str = 'ollama-chat'
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
|||||||
|
|
||||||
async def _closure(
|
async def _closure(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
req_messages: list[dict],
|
req_messages: list[dict],
|
||||||
use_model: requester.RuntimeLLMModel,
|
use_model: requester.RuntimeLLMModel,
|
||||||
use_funcs: list[resource_tool.LLMTool] = None,
|
use_funcs: list[resource_tool.LLMTool] = None,
|
||||||
@@ -105,7 +105,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
|
|||||||
|
|
||||||
async def invoke_llm(
|
async def invoke_llm(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
model: requester.RuntimeLLMModel,
|
model: requester.RuntimeLLMModel,
|
||||||
messages: typing.List[llm_entities.Message],
|
messages: typing.List[llm_entities.Message],
|
||||||
funcs: typing.List[resource_tool.LLMTool] = None,
|
funcs: typing.List[resource_tool.LLMTool] = None,
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app
|
||||||
from . import entities as llm_entities
|
from . import entities as llm_entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregistered_runners: list[typing.Type[RequestRunner]] = []
|
preregistered_runners: list[typing.Type[RequestRunner]] = []
|
||||||
@@ -35,6 +36,6 @@ class RequestRunner(abc.ABC):
|
|||||||
self.pipeline_config = pipeline_config
|
self.pipeline_config = pipeline_config
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""运行请求"""
|
"""运行请求"""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ import re
|
|||||||
import dashscope
|
import dashscope
|
||||||
|
|
||||||
from .. import runner
|
from .. import runner
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from .. import entities as llm_entities
|
from .. import entities as llm_entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class DashscopeAPIError(Exception):
|
class DashscopeAPIError(Exception):
|
||||||
@@ -65,7 +66,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
|||||||
# 使用 re.sub() 进行替换
|
# 使用 re.sub() 进行替换
|
||||||
return pattern.sub(replacement, text)
|
return pattern.sub(replacement, text)
|
||||||
|
|
||||||
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
|
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
|
||||||
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
|
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
|
||||||
plain_text = ''
|
plain_text = ''
|
||||||
image_ids = []
|
image_ids = []
|
||||||
@@ -89,7 +90,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
return plain_text, image_ids
|
return plain_text, image_ids
|
||||||
|
|
||||||
async def _agent_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def _agent_messages(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""Dashscope 智能体对话请求"""
|
"""Dashscope 智能体对话请求"""
|
||||||
|
|
||||||
# 局部变量
|
# 局部变量
|
||||||
@@ -147,7 +148,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
|||||||
content=pending_content,
|
content=pending_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def _workflow_messages(
|
||||||
|
self, query: pipeline_query.Query
|
||||||
|
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""Dashscope 工作流对话请求"""
|
"""Dashscope 工作流对话请求"""
|
||||||
|
|
||||||
# 局部变量
|
# 局部变量
|
||||||
@@ -210,7 +213,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
|||||||
content=pending_content,
|
content=pending_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""运行"""
|
"""运行"""
|
||||||
if self.app_type == 'agent':
|
if self.app_type == 'agent':
|
||||||
async for msg in self._agent_messages(query):
|
async for msg in self._agent_messages(query):
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ import base64
|
|||||||
|
|
||||||
|
|
||||||
from .. import runner
|
from .. import runner
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from .. import entities as llm_entities
|
from .. import entities as llm_entities
|
||||||
from ...utils import image
|
from ...utils import image
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
from libs.dify_service_api.v1 import client, errors
|
from libs.dify_service_api.v1 import client, errors
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
|
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
|
||||||
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
|
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
|
||||||
|
|
||||||
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
|
async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
|
||||||
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
|
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -90,7 +90,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
return plain_text, image_ids
|
return plain_text, image_ids
|
||||||
|
|
||||||
async def _chat_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def _chat_messages(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""调用聊天助手"""
|
"""调用聊天助手"""
|
||||||
cov_id = query.session.using_conversation.uuid or ''
|
cov_id = query.session.using_conversation.uuid or ''
|
||||||
query.variables['conversation_id'] = cov_id
|
query.variables['conversation_id'] = cov_id
|
||||||
@@ -152,7 +152,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||||
|
|
||||||
async def _agent_chat_messages(
|
async def _agent_chat_messages(
|
||||||
self, query: core_entities.Query
|
self, query: pipeline_query.Query
|
||||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""调用聊天助手"""
|
"""调用聊天助手"""
|
||||||
cov_id = query.session.using_conversation.uuid or ''
|
cov_id = query.session.using_conversation.uuid or ''
|
||||||
@@ -244,7 +244,9 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||||
|
|
||||||
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def _workflow_messages(
|
||||||
|
self, query: pipeline_query.Query
|
||||||
|
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""调用工作流"""
|
"""调用工作流"""
|
||||||
|
|
||||||
if not query.session.using_conversation.uuid:
|
if not query.session.using_conversation.uuid:
|
||||||
@@ -316,7 +318,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
yield msg
|
yield msg
|
||||||
|
|
||||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""运行请求"""
|
"""运行请求"""
|
||||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
||||||
async for msg in self._chat_messages(query):
|
async for msg in self._chat_messages(query):
|
||||||
|
|||||||
@@ -4,15 +4,15 @@ import json
|
|||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import runner
|
from .. import runner
|
||||||
from ...core import entities as core_entities
|
|
||||||
from .. import entities as llm_entities
|
from .. import entities as llm_entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@runner.runner_class('local-agent')
|
@runner.runner_class('local-agent')
|
||||||
class LocalAgentRunner(runner.RequestRunner):
|
class LocalAgentRunner(runner.RequestRunner):
|
||||||
"""本地Agent请求运行器"""
|
"""本地Agent请求运行器"""
|
||||||
|
|
||||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""运行请求"""
|
"""运行请求"""
|
||||||
pending_tool_calls = []
|
pending_tool_calls = []
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ import uuid
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from .. import runner
|
from .. import runner
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from .. import entities as llm_entities
|
from .. import entities as llm_entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class N8nAPIError(Exception):
|
class N8nAPIError(Exception):
|
||||||
@@ -49,7 +50,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
self.header_name = self.pipeline_config['ai']['n8n-service-api'].get('header-name', '')
|
self.header_name = self.pipeline_config['ai']['n8n-service-api'].get('header-name', '')
|
||||||
self.header_value = self.pipeline_config['ai']['n8n-service-api'].get('header-value', '')
|
self.header_value = self.pipeline_config['ai']['n8n-service-api'].get('header-value', '')
|
||||||
|
|
||||||
async def _preprocess_user_message(self, query: core_entities.Query) -> str:
|
async def _preprocess_user_message(self, query: pipeline_query.Query) -> str:
|
||||||
"""预处理用户消息,提取纯文本
|
"""预处理用户消息,提取纯文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -67,7 +68,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
|
|
||||||
return plain_text
|
return plain_text
|
||||||
|
|
||||||
async def _call_webhook(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""调用n8n webhook"""
|
"""调用n8n webhook"""
|
||||||
# 生成会话ID(如果不存在)
|
# 生成会话ID(如果不存在)
|
||||||
if not query.session.using_conversation.uuid:
|
if not query.session.using_conversation.uuid:
|
||||||
@@ -153,7 +154,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
|
|||||||
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
|
self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
|
||||||
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
|
raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
|
||||||
|
|
||||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
"""运行请求"""
|
"""运行请求"""
|
||||||
async for msg in self._call_webhook(query):
|
async for msg in self._call_webhook(query):
|
||||||
yield msg
|
yield msg
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message, prompt as provider_prompt
|
from langbot_plugin.api.entities.builtin.provider import message as provider_message, prompt as provider_prompt
|
||||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
@@ -21,7 +22,7 @@ class SessionManager:
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def get_session(self, query: core_entities.Query) -> provider_session.Session:
|
async def get_session(self, query: pipeline_query.Query) -> provider_session.Session:
|
||||||
"""获取会话"""
|
"""获取会话"""
|
||||||
for session in self.session_list:
|
for session in self.session_list:
|
||||||
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
||||||
@@ -39,7 +40,7 @@ class SessionManager:
|
|||||||
|
|
||||||
async def get_conversation(
|
async def get_conversation(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
session: provider_session.Session,
|
session: provider_session.Session,
|
||||||
prompt_config: list[dict],
|
prompt_config: list[dict],
|
||||||
pipeline_uuid: str,
|
pipeline_uuid: str,
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregistered_loaders: list[typing.Type[ToolLoader]] = []
|
preregistered_loaders: list[typing.Type[ToolLoader]] = []
|
||||||
@@ -45,7 +46,7 @@ class ToolLoader(abc.ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
|
||||||
"""执行工具调用"""
|
"""执行工具调用"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ from mcp.client.stdio import stdio_client
|
|||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
from .. import loader
|
from .. import loader
|
||||||
from ....core import app, entities as core_entities
|
from ....core import app
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class RuntimeMCPSession:
|
class RuntimeMCPSession:
|
||||||
@@ -83,7 +84,7 @@ class RuntimeMCPSession:
|
|||||||
|
|
||||||
for tool in tools.tools:
|
for tool in tools.tools:
|
||||||
|
|
||||||
async def func(query: core_entities.Query, *, _tool=tool, **kwargs):
|
async def func(query: pipeline_query.Query, *, _tool=tool, **kwargs):
|
||||||
result = await self.session.call_tool(_tool.name, kwargs)
|
result = await self.session.call_tool(_tool.name, kwargs)
|
||||||
if result.isError:
|
if result.isError:
|
||||||
raise Exception(result.content[0].text)
|
raise Exception(result.content[0].text)
|
||||||
@@ -144,7 +145,7 @@ class MCPLoader(loader.ToolLoader):
|
|||||||
async def has_tool(self, name: str) -> bool:
|
async def has_tool(self, name: str) -> bool:
|
||||||
return name in [f.name for f in self._last_listed_functions]
|
return name in [f.name for f in self._last_listed_functions]
|
||||||
|
|
||||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
|
||||||
for server_name, session in self.sessions.items():
|
for server_name, session in self.sessions.items():
|
||||||
for function in session.functions:
|
for function in session.functions:
|
||||||
if function.name == name:
|
if function.name == name:
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ import typing
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from .. import loader
|
from .. import loader
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....plugin import context as plugin_context
|
from ....plugin import context as plugin_context
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@loader.loader_class('plugin-tool-loader')
|
@loader.loader_class('plugin-tool-loader')
|
||||||
@@ -49,7 +49,7 @@ class PluginToolLoader(loader.ToolLoader):
|
|||||||
return function, plugin.plugin_inst
|
return function, plugin.plugin_inst
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
|
||||||
try:
|
try:
|
||||||
function, plugin = await self._get_function_and_plugin(name)
|
function, plugin = await self._get_function_and_plugin(name)
|
||||||
if function is None:
|
if function is None:
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from . import loader as tools_loader
|
from . import loader as tools_loader
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
from . import loaders
|
from . import loaders
|
||||||
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(loaders)
|
importutil.import_modules_in_pkg(loaders)
|
||||||
|
|
||||||
@@ -90,7 +91,7 @@ class ToolManager:
|
|||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
async def execute_func_call(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any:
|
async def execute_func_call(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
|
||||||
"""执行函数调用"""
|
"""执行函数调用"""
|
||||||
|
|
||||||
for loader in self.loaders:
|
for loader in self.loaders:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ..core import app
|
from ..core import app
|
||||||
|
|||||||
Reference in New Issue
Block a user