Compare commits

...

39 Commits

Author SHA1 Message Date
RockChinQ
9e6a01fefd chore: release v3.2.2 2024-05-31 19:20:34 +08:00
RockChinQ
933471b4d9 perf: 启动失败时输出完整traceback (#799) 2024-05-31 15:37:56 +08:00
RockChinQ
f81808d239 perf: 添加JSON配置文件语法检查 (#796) 2024-05-29 21:11:21 +08:00
RockChinQ
96832b6f7d perf: 忽略空的 assistant content 消息 (#795) 2024-05-29 21:00:48 +08:00
Junyan Qin
e2eb0a84b0 Merge pull request #797 from RockChinQ/feat/context-truncater
Feat: 消息截断器
2024-05-29 20:38:14 +08:00
RockChinQ
c8eb2e3376 feat: 消息截断器 2024-05-29 20:34:49 +08:00
Junyan Qin
21fe5822f9 Merge pull request #794 from RockChinQ/perf/advanced-fixwin
Feat: fixwin限速支持设置窗口大小
2024-05-26 10:33:49 +08:00
RockChinQ
d49cc9a7a3 feat: fixwin限速支持设置窗口大小 (#791) 2024-05-26 10:29:10 +08:00
Junyan Qin
910d0bfae1 Update README.md 2024-05-25 12:27:27 +08:00
RockChinQ
d6761949ca chore: release v3.2.1 2024-05-23 16:29:26 +08:00
RockChinQ
6afac1f593 feat: 允许指定遥测服务器url 2024-05-23 16:25:51 +08:00
RockChinQ
4d1a270d22 doc: 添加qcg-center源码链接 2024-05-23 16:16:13 +08:00
Junyan Qin
a7888f5536 Merge pull request #787 from RockChinQ/perf/claude-ability
Perf: Claude 的能力完善支持
2024-05-22 20:33:39 +08:00
RockChinQ
b9049e91cf chore: 同步 llm-models.json 2024-05-22 20:31:46 +08:00
RockChinQ
7db56c8e77 feat: claude 支持视觉 2024-05-22 20:09:29 +08:00
Junyan Qin
50563cb957 Merge pull request #785 from RockChinQ/fix/msg-chain-compability
Fix: 修复 query.resp_messages 对插件reply的兼容性
2024-05-18 20:13:50 +08:00
RockChinQ
18ae2299a7 fix: 修复 query.resp_messages 对插件reply的兼容性 2024-05-18 20:08:48 +08:00
RockChinQ
7463e0aab9 perf: 删除多个地方残留的 config.py 字段 (#781) 2024-05-18 18:52:45 +08:00
Junyan Qin
c92d47bb95 Merge pull request #779 from jerryliang122/master
修复aiocqhttp的图片错误
2024-05-17 17:05:58 +08:00
RockChinQ
0b1af7df91 perf: 统一判断方式 2024-05-17 17:05:20 +08:00
jerryliang122
a9104eb2da 通过base64编码发送,修复cqhttp无法发送图片 2024-05-17 08:20:06 +00:00
RockChinQ
abbd15d5cc chore: release v3.2.0.1 2024-05-17 09:48:20 +08:00
RockChinQ
aadfa14d59 fix: claude 请求失败 2024-05-17 09:46:06 +08:00
Junyan Qin
4cd10bbe25 Update README.md 2024-05-16 22:17:46 +08:00
RockChinQ
1d4a6b71ab chore: release v3.2.0 2024-05-16 21:22:40 +08:00
Junyan Qin
a7f830dd73 Merge pull request #773 from RockChinQ/feat/multi-modal
Feat: 多模态
2024-05-16 21:13:15 +08:00
RockChinQ
bae86ac05c chore: 恢复版本号 2024-05-16 21:03:56 +08:00
RockChinQ
a3706bfe21 perf: 细节优化 2024-05-16 21:02:59 +08:00
RockChinQ
91e23b8c11 perf: 为图片base64函数添加lru 2024-05-16 20:52:17 +08:00
RockChinQ
37ef1c9fab feat: 删除oss相关代码 2024-05-16 20:32:30 +08:00
RockChinQ
6bc6f77af1 feat: 通过 base64 传输图片 2024-05-16 20:25:51 +08:00
RockChinQ
2c478ccc25 feat: 模型vision支持性参数 2024-05-16 20:11:54 +08:00
RockChinQ
404e5492a3 chore: 同步现有模型信息 2024-05-16 18:29:23 +08:00
RockChinQ
d5b5d667a5 feat: 模型视觉多模态支持 2024-05-15 21:40:18 +08:00
RockChinQ
8807f02f36 perf: resp_message_chain 改为 list 类型 (#770) 2024-05-14 23:08:49 +08:00
RockChinQ
269e561497 perf: messages 存回 conversation 应该仅在成功执行本次请求时执行 (#769) 2024-05-14 22:41:39 +08:00
RockChinQ
527ad81d38 feat: 解藕chat的处理器和请求器 (#772) 2024-05-14 22:20:31 +08:00
Junyan Qin
972d3c18af Update README.md 2024-05-08 21:49:45 +08:00
Junyan Qin
3cbfc078fc doc(README.md): 更新 社区四群群号 2024-05-08 21:46:19 +08:00
61 changed files with 1088 additions and 399 deletions

View File

@@ -19,8 +19,8 @@
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=66-aWvn8cbP4c1ut_1YYkvvGVeEtyTH8&authKey=pTaKBK5C%2B8dFzQ4XlENf6MHTCLaHnlKcCRx7c14EeVVlpX2nRSaS8lJm8YeM4mCU&noverify=0&group_code=195992197"> <a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=66-aWvn8cbP4c1ut_1YYkvvGVeEtyTH8&authKey=pTaKBK5C%2B8dFzQ4XlENf6MHTCLaHnlKcCRx7c14EeVVlpX2nRSaS8lJm8YeM4mCU&noverify=0&group_code=195992197">
<img alt="Static Badge" src="https://img.shields.io/badge/%E5%AE%98%E6%96%B9%E7%BE%A4-195992197-purple"> <img alt="Static Badge" src="https://img.shields.io/badge/%E5%AE%98%E6%96%B9%E7%BE%A4-195992197-purple">
</a> </a>
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=nC80H57wmKPwRDLFeQrDDjVl81XuC21P&authKey=2wTUTfoQ5v%2BD4C5zfpuR%2BSPMDqdXgDXA%2FS2wHI1NxTfWIG%2B%2FqK08dgyjMMOzhXa9&noverify=0&group_code=248432104"> <a href="https://qm.qq.com/q/1yxEaIgXMA">
<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-248432104-purple"> <img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-619154800-purple">
</a> </a>
<a href="https://codecov.io/gh/RockChinQ/QChatGPT" > <a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/> <img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
@@ -39,7 +39,17 @@
<a href="https://github.com/RockChinQ/qcg-installer">安装器源码</a> <a href="https://github.com/RockChinQ/qcg-installer">安装器源码</a>
<a href="https://github.com/RockChinQ/qcg-tester">测试工程源码</a> <a href="https://github.com/RockChinQ/qcg-tester">测试工程源码</a>
<a href="https://github.com/RockChinQ/qcg-center">遥测服务端源码</a>
<a href="https://github.com/the-lazy-me/QChatGPT-Wiki">官方文档储存库</a> <a href="https://github.com/the-lazy-me/QChatGPT-Wiki">官方文档储存库</a>
<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/assets/image/QChatGPT-1211.png" width="500px"/> <hr/>
<div align="center">
京东云4090单卡15C90G实例 <br/>
仅需1.89/小时包月1225元起 <br/>
可选预装Stable Diffusion等应用随用随停计费透明欢迎首选支持 <br/>
https://3.cn/1ZOi6-Gj
</div>
<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/assets/image/QChatGPT-0516.png" width="500px"/>
</div> </div>

View File

@@ -9,8 +9,6 @@ from .groups import plugin
from ...core import app from ...core import app
BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2"
class V2CenterAPI: class V2CenterAPI:
"""中央服务器 v2 API 交互类""" """中央服务器 v2 API 交互类"""
@@ -23,7 +21,7 @@ class V2CenterAPI:
plugin: plugin.V2PluginDataAPI = None plugin: plugin.V2PluginDataAPI = None
"""插件 API 组""" """插件 API 组"""
def __init__(self, ap: app.Application, basic_info: dict = None, runtime_info: dict = None): def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None):
"""初始化""" """初始化"""
logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info) logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
@@ -31,7 +29,7 @@ class V2CenterAPI:
apigroup.APIGroup._basic_info = basic_info apigroup.APIGroup._basic_info = basic_info
apigroup.APIGroup._runtime_info = runtime_info apigroup.APIGroup._runtime_info = runtime_info
self.main = main.V2MainDataAPI(BACKEND_URL, ap) self.main = main.V2MainDataAPI(backend_url, ap)
self.usage = usage.V2UsageDataAPI(BACKEND_URL, ap) self.usage = usage.V2UsageDataAPI(backend_url, ap)
self.plugin = plugin.V2PluginDataAPI(BACKEND_URL, ap) self.plugin = plugin.V2PluginDataAPI(backend_url, ap)

View File

@@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel):
"""命令返回值 """命令返回值
""" """
text: typing.Optional[str] text: typing.Optional[str] = None
"""文本 """文本
""" """
image: typing.Optional[mirai.Image] image: typing.Optional[mirai.Image] = None
"""弃用"""
image_url: typing.Optional[str] = None
"""图片链接
"""
error: typing.Optional[errors.CommandError]= None error: typing.Optional[errors.CommandError]= None
"""错误 """错误

View File

@@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator):
content = "" content = ""
for msg in prompt.messages: for msg in prompt.messages:
content += f" {msg.role}: {msg.content}" content += f" {msg.readable_str()}\n"
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
@@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator):
context: entities.ExecuteContext context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else: else:
prompt_name = context.crt_params[0] prompt_name = context.crt_params[0]
try: try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None: if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else: else:
context.session.use_prompt_name = prompt.name context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator):
context.session.using_conversation = context.session.conversations[index-1] context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
return return
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -42,7 +42,7 @@ class ListOperator(operator.CommandOperator):
using_conv_index = index using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page: if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n" content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
index += 1 index += 1
if content == '': if content == '':
@@ -51,6 +51,6 @@ class ListOperator(operator.CommandOperator):
if context.session.using_conversation is None: if context.session.using_conversation is None:
content += "\n当前处于新会话" content += "\n当前处于新会话"
else: else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}" content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"
yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}") yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")

View File

@@ -27,7 +27,7 @@ class JSONConfigFile(file_model.ConfigFile):
else: else:
raise ValueError("template_file_name or template_data must be provided") raise ValueError("template_file_name or template_data must be provided")
async def load(self) -> dict: async def load(self, completion: bool=True) -> dict:
if not self.exists(): if not self.exists():
await self.create() await self.create()
@@ -37,11 +37,16 @@ class JSONConfigFile(file_model.ConfigFile):
self.template_data = json.load(f) self.template_data = json.load(f)
with open(self.config_file_name, "r", encoding="utf-8") as f: with open(self.config_file_name, "r", encoding="utf-8") as f:
cfg = json.load(f) try:
cfg = json.load(f)
except json.JSONDecodeError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
for key in self.template_data: if completion:
if key not in cfg:
cfg[key] = self.template_data[key] for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
return cfg return cfg

View File

@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def create(self): async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name) shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self) -> dict: async def load(self, completion: bool=True) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0] module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
@@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile):
cfg[key] = getattr(module, key) cfg[key] = getattr(module, key)
# 从模板模块文件中进行补全 # 从模板模块文件中进行补全
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] if completion:
module = importlib.import_module(module_name) module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
module = importlib.import_module(module_name)
for key in dir(module): for key in dir(module):
if key.startswith('__'): if key.startswith('__'):
continue continue
if not isinstance(getattr(module, key), allowed_types): if not isinstance(getattr(module, key), allowed_types):
continue continue
if key not in cfg: if key not in cfg:
cfg[key] = getattr(module, key) cfg[key] = getattr(module, key)
return cfg return cfg

View File

@@ -20,8 +20,8 @@ class ConfigManager:
self.file = cfg_file self.file = cfg_file
self.data = {} self.data = {}
async def load_config(self): async def load_config(self, completion: bool=True):
self.data = await self.file.load() self.data = await self.file.load(completion=completion)
async def dump_config(self): async def dump_config(self):
await self.file.save(self.data) await self.file.save(self.data)
@@ -30,7 +30,7 @@ class ConfigManager:
self.file.save_sync(self.data) self.file.save_sync(self.data)
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager: async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
"""加载Python模块配置文件""" """加载Python模块配置文件"""
cfg_inst = pymodule.PythonModuleConfigFile( cfg_inst = pymodule.PythonModuleConfigFile(
config_name, config_name,
@@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
) )
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config() await cfg_mgr.load_config(completion=completion)
return cfg_mgr return cfg_mgr
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager: async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
"""加载JSON配置文件""" """加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile( cfg_inst = json_file.JSONConfigFile(
config_name, config_name,
@@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
) )
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config() await cfg_mgr.load_config(completion=completion)
return cfg_mgr return cfg_mgr

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("vision-config", 6)
class VisionConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "enable-vision" not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
if "enable-vision" not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data["enable-vision"] = False
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("qcg-center-url-config", 7)
class QCGCenterURLConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "qcg-center-url" not in self.ap.system_cfg.data
async def run(self):
"""执行迁移"""
if "qcg-center-url" not in self.ap.system_cfg.data:
self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2"
await self.ap.system_cfg.dump_config()

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("ad-fixwin-cfg-migration", 8)
class AdFixwinConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return isinstance(
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"],
int
)
async def run(self):
"""执行迁移"""
for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]:
temp_dict = {
"window-size": 60,
"limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name]
}
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict
await self.ap.pipeline_cfg.dump_config()

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("msg-truncator-cfg-migration", 9)
class MsgTruncatorConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'msg-truncate' not in self.ap.pipeline_cfg.data
async def run(self):
"""执行迁移"""
self.ap.pipeline_cfg.data['msg-truncate'] = {
'method': 'round',
'round': {
'max-round': 10
}
}
await self.ap.pipeline_cfg.dump_config()

View File

@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def load(self) -> dict: async def load(self, completion: bool=True) -> dict:
pass pass
@abc.abstractmethod @abc.abstractmethod

View File

@@ -1,5 +1,7 @@
from __future__ import print_function from __future__ import print_function
import traceback
from . import app from . import app
from ..audit import identifier from ..audit import identifier
from . import stage from . import stage
@@ -27,6 +29,7 @@ async def make_app() -> app.Application:
for stage_name in stage_order: for stage_name in stage_order:
stage_cls = stage.preregistered_stages[stage_name] stage_cls = stage.preregistered_stages[stage_name]
stage_inst = stage_cls() stage_inst = stage_cls()
await stage_inst.run(ap) await stage_inst.run(ap)
await ap.initialize() await ap.initialize()
@@ -35,5 +38,8 @@ async def make_app() -> app.Application:
async def main(): async def main():
app_inst = await make_app() try:
await app_inst.run() app_inst = await make_app()
await app_inst.run()
except Exception as e:
traceback.print_exc()

View File

@@ -8,16 +8,3 @@ from ...config.impls import pymodule
load_python_module_config = config_mgr.load_python_module_config load_python_module_config = config_mgr.load_python_module_config
load_json_config = config_mgr.load_json_config load_json_config = config_mgr.load_json_config
async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:
override_json = json.load(open("override.json", "r", encoding="utf-8"))
overrided = []
config = cfg_mgr.data
for key in override_json:
if key in config:
config[key] = override_json[key]
overrided.append(key)
return overrided

View File

@@ -14,6 +14,7 @@ required_deps = {
"yaml": "pyyaml", "yaml": "pyyaml",
"aiohttp": "aiohttp", "aiohttp": "aiohttp",
"psutil": "psutil", "psutil": "psutil",
"async_lru": "async-lru",
} }

View File

@@ -67,10 +67,10 @@ class Query(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置""" """使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] = [] resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = []
"""由Process阶段生成的回复消息对象列表""" """由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
"""回复消息链从resp_messages包装而得""" """回复消息链从resp_messages包装而得"""
class Config: class Config:

View File

@@ -15,7 +15,6 @@ from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr from ...platform import manager as im_mgr
@stage.stage_class("BuildAppStage") @stage.stage_class("BuildAppStage")
class BuildAppStage(stage.BootingStage): class BuildAppStage(stage.BootingStage):
"""构建应用阶段 """构建应用阶段
@@ -35,6 +34,7 @@ class BuildAppStage(stage.BootingStage):
center_v2_api = center_v2.V2CenterAPI( center_v2_api = center_v2.V2CenterAPI(
ap, ap,
backend_url=ap.system_cfg.data["qcg-center-url"],
basic_info={ basic_info={
"host_id": identifier.identifier["host_id"], "host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"], "instance_id": identifier.identifier["instance_id"],
@@ -83,7 +83,6 @@ class BuildAppStage(stage.BootingStage):
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize() await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap) im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize() await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst ap.platform_mgr = im_mgr_inst
@@ -92,5 +91,6 @@ class BuildAppStage(stage.BootingStage):
await stage_mgr.initialize() await stage_mgr.initialize()
ap.stage_mgr = stage_mgr ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap) ctrl = controller.Controller(ap)
ap.ctrl = ctrl ap.ctrl = ctrl

View File

@@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage):
async def run(self, ap: app.Application): async def run(self, ap: app.Application):
"""启动 """启动
""" """
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json") ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json") ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json") ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json") ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json") ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json") ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
await ap.plugin_setting_meta.dump_config() await ap.plugin_setting_meta.dump_config()

View File

@@ -5,7 +5,7 @@ import importlib
from .. import stage, app from .. import stage, app
from ...config import migration from ...config import migration
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ...config.migrations import m005_deepseek_cfg_completion from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
@stage.stage_class("MigrationStage") @stage.stage_class("MigrationStage")

View File

@@ -8,7 +8,10 @@ from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage') @stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage): class BanSessionCheckStage(stage.PipelineStage):
"""访问控制处理阶段""" """访问控制处理阶段
仅检查query中群号或个人号是否在访问控制列表中。
"""
async def initialize(self): async def initialize(self):
pass pass

View File

@@ -9,12 +9,24 @@ from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
@stage.stage_class('PostContentFilterStage') @stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage') @stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage): class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段""" """内容过滤阶段
前置:
检查消息是否符合规则,不符合则拦截。
改写:
message_chain
后置:
检查AI回复消息是否符合规则可能进行改写不符合则拦截。
改写:
query.resp_messages
"""
filter_chain: list[filter_model.ContentFilter] filter_chain: list[filter_model.ContentFilter]
@@ -130,19 +142,34 @@ class ContentFilterStage(stage.PipelineStage):
"""处理 """处理
""" """
if stage_inst_name == 'PreContentFilterStage': if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
for me in query.message_chain:
if not isinstance(me, mirai.Plain):
contain_non_text = True
break
if contain_non_text:
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
return await self._pre_process( return await self._pre_process(
str(query.message_chain).strip(), str(query.message_chain).strip(),
query query
) )
elif stage_inst_name == 'PostContentFilterStage': elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况 # 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1].content, str): if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
return await self._post_process( return await self._post_process(
query.resp_messages[-1].content, query.resp_messages[-1].content,
query query
) )
else: else:
self.ap.logger.debug(f"resp_messages[-1] 不是 str 类型,跳过内容过滤器检查。") self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query

View File

@@ -4,6 +4,8 @@ import enum
import pydantic import pydantic
from ...provider import entities as llm_entities
class ResultLevel(enum.Enum): class ResultLevel(enum.Enum):
"""结果等级""" """结果等级"""
@@ -38,7 +40,7 @@ class FilterResult(pydantic.BaseModel):
""" """
replacement: str replacement: str
"""替换后的消息 """替换后的文本消息
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。 内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
若没有修改内容,也需要返回原消息。 若没有修改内容,也需要返回原消息。

View File

@@ -5,6 +5,7 @@ import typing
from ...core import app from ...core import app
from . import entities from . import entities
from ...provider import entities as llm_entities
preregistered_filters: list[typing.Type[ContentFilter]] = [] preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -63,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult: async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
"""处理消息 """处理消息
分为前后阶段,具体取决于 enable_stages 的值。 分为前后阶段,具体取决于 enable_stages 的值。
@@ -71,6 +72,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
Args: Args:
message (str): 需要检查的内容 message (str): 需要检查的内容
image_url (str): 要检查的图片的 URL
Returns: Returns:
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档 entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档

View File

@@ -8,7 +8,7 @@ from ....config import manager as cfg_mgr
@filter_model.filter_class("ban-word-filter") @filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter): class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言""" """根据内容过滤"""
async def initialize(self): async def initialize(self):
pass pass

View File

@@ -16,6 +16,9 @@ from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage") @stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage): class LongTextProcessStage(stage.PipelineStage):
"""长消息处理阶段 """长消息处理阶段
改写:
- resp_message_chain
""" """
strategy_impl: strategy.LongTextStrategy strategy_impl: strategy.LongTextStrategy
@@ -31,18 +34,18 @@ class LongTextProcessStage(stage.PipelineStage):
if os.name == "nt": if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc" use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font): if not os.path.exists(use_font):
self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。") self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在配置文件中调整相关设置。")
config['blob_message_strategy'] = "forward" config['blob_message_strategy'] = "forward"
else: else:
self.ap.logger.info("使用Windows自带字体" + use_font) self.ap.logger.info("使用Windows自带字体" + use_font)
config['font-path'] = use_font config['font-path'] = use_font
else: else:
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。") self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
except: except:
traceback.print_exc() traceback.print_exc()
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font)) self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
@@ -59,15 +62,15 @@ class LongTextProcessStage(stage.PipelineStage):
# 检查是否包含非 Plain 组件 # 检查是否包含非 Plain 组件
contains_non_plain = False contains_non_plain = False
for msg in query.resp_message_chain: for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain): if not isinstance(msg, Plain):
contains_non_plain = True contains_non_plain = True
break break
if contains_non_plain: if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query)) query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from . import truncator
from .truncators import round
@stage.stage_class("ConversationMessageTruncator")
class ConversationMessageTruncator(stage.PipelineStage):
"""会话消息截断器
用于截断会话消息链,以适应平台消息长度限制。
"""
trun: truncator.Truncator
async def initialize(self):
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
for trun in truncator.preregistered_truncators:
if trun.name == use_method:
self.trun = trun(self.ap)
break
else:
raise ValueError(f"未知的截断器: {use_method}")
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import typing
import abc
from ...core import entities as core_entities, app
preregistered_truncators: list[typing.Type[Truncator]] = []
def truncator_class(
name: str
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
"""截断器类装饰器
Args:
name (str): 截断器名称
Returns:
typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器
"""
def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
assert issubclass(cls, Truncator)
cls.name = name
preregistered_truncators.append(cls)
return cls
return decorator
class Truncator(abc.ABC):
"""消息截断器基类
"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。
请勿操作其他字段。
"""
pass

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from .. import truncator
from ....core import entities as core_entities
@truncator.truncator_class("round")
class RoundTruncator(truncator.Truncator):
"""前文回合数阶段器
"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
"""
max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']
temp_messages = []
current_round = 0
# 从后往前遍历
for msg in query.messages[::-1]:
if current_round < max_round:
temp_messages.append(msg)
if msg.role == 'user':
current_round += 1
else:
break
query.messages = temp_messages[::-1]
return query

View File

@@ -43,7 +43,7 @@ class QueryPool:
message_event=message_event, message_event=message_event,
message_chain=message_chain, message_chain=message_chain,
resp_messages=[], resp_messages=[],
resp_message_chain=None, resp_message_chain=[],
adapter=adapter adapter=adapter
) )
self.queries.append(query) self.queries.append(query)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import mirai
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...provider import entities as llm_entities from ...provider import entities as llm_entities
@@ -9,6 +11,16 @@ from ...plugin import events
@stage.stage_class("PreProcessor") @stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage): class PreProcessor(stage.PipelineStage):
"""请求预处理阶段 """请求预处理阶段
签出会话、prompt、上文、模型、内容函数。
改写:
- session
- prompt
- messages
- user_message
- use_model
- use_funcs
""" """
async def process( async def process(
@@ -27,21 +39,42 @@ class PreProcessor(stage.PipelineStage):
query.prompt = conversation.prompt.copy() query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy() query.messages = conversation.messages.copy()
query.user_message = llm_entities.Message(
role='user',
content=str(query.message_chain).strip()
)
query.use_model = conversation.use_model query.use_model = conversation.use_model
query.use_funcs = conversation.use_funcs query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
# 检查vision是否启用没启用就删除所有图片
if not self.ap.provider_cfg.data['enable-vision'] or not query.use_model.vision_supported:
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
if me.type == 'image_url':
msg.content.remove(me)
content_list = []
for me in query.message_chain:
if isinstance(me, mirai.Plain):
content_list.append(
llm_entities.ContentElement.from_text(me.text)
)
elif isinstance(me, mirai.Image):
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
if me.url is not None:
content_list.append(
llm_entities.ContentElement.from_image_url(str(me.url))
)
query.user_message = llm_entities.Message( # TODO 适配多模态输入
role='user',
content=content_list
)
# =========== 触发事件 PromptPreProcessing # =========== 触发事件 PromptPreProcessing
session = query.session
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing( event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}', session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages, default_prompt=query.prompt.messages,
prompt=query.messages, prompt=query.messages,
query=query query=query

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import typing import typing
import time import time
import traceback import traceback
import json
import mirai import mirai
@@ -41,12 +42,7 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply) mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append( query.resp_messages.append(mc)
llm_entities.Message(
role='plugin',
content=mc,
)
)
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@@ -70,17 +66,13 @@ class ChatMessageHandler(handler.MessageHandler):
mirai.Plain(event_ctx.event.alter) mirai.Plain(event_ctx.event.alter)
]) ])
query.messages.append(
query.user_message
)
text_length = 0 text_length = 0
start_time = time.time() start_time = time.time()
try: try:
async for result in query.use_model.requester.request(query): async for result in self.runner(query):
query.resp_messages.append(result) query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
@@ -92,6 +84,9 @@ class ChatMessageHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e: except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}') self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
@@ -104,8 +99,6 @@ class ChatMessageHandler(handler.MessageHandler):
debug_notice=traceback.format_exc() debug_notice=traceback.format_exc()
) )
finally: finally:
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
await self.ap.ctr_mgr.usage.post_query_record( await self.ap.ctr_mgr.usage.post_query_record(
session_type=query.session.launcher_type.value, session_type=query.session.launcher_type.value,
@@ -115,4 +108,65 @@ class ChatMessageHandler(handler.MessageHandler):
model_name=query.use_model.name, model_name=query.use_model.name,
response_seconds=int(time.time() - start_time), response_seconds=int(time.time() - start_time),
retry_times=-1, retry_times=-1,
) )
async def runner(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
这是临时处理方案后续可能改为使用LangChain或者自研的工作流处理器
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)

View File

@@ -48,12 +48,7 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply) mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append( query.resp_messages.append(mc)
llm_entities.Message(
role='command',
content=str(mc),
)
)
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@@ -80,9 +75,6 @@ class CommandHandler(handler.MessageHandler):
session=session session=session
): ):
if ret.error is not None: if ret.error is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(str(ret.error))
# ])
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( llm_entities.Message(
role='command', role='command',
@@ -96,18 +88,28 @@ class CommandHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
elif ret.text is not None: elif ret.text is not None or ret.image_url is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(ret.text) content: list[llm_entities.ContentElement]= []
# ])
if ret.text is not None:
content.append(
llm_entities.ContentElement.from_text(ret.text)
)
if ret.image_url is not None:
content.append(
llm_entities.ContentElement.from_image_url(ret.image_url)
)
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( llm_entities.Message(
role='command', role='command',
content=ret.text, content=content,
) )
) )
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}') self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@@ -11,7 +11,13 @@ from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor") @stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage): class Processor(stage.PipelineStage):
"""请求实际处理阶段""" """请求实际处理阶段
通过命令处理器和聊天处理器处理消息。
改写:
- resp_messages
"""
cmd_handler: handler.MessageHandler cmd_handler: handler.MessageHandler

View File

@@ -1,18 +1,15 @@
# 固定窗口算法
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import time import time
from .. import algo from .. import algo
# 固定窗口算法
class SessionContainer: class SessionContainer:
wait_lock: asyncio.Lock wait_lock: asyncio.Lock
records: dict[int, int] records: dict[int, int]
"""访问记录key为每分钟的起始时间戳value为访问次数""" """访问记录key为每窗口长度的起始时间戳value为访问次数"""
def __init__(self): def __init__(self):
self.wait_lock = asyncio.Lock() self.wait_lock = asyncio.Lock()
@@ -47,30 +44,34 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 等待锁 # 等待锁
async with container.wait_lock: async with container.wait_lock:
# 获取窗口大小和限制
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit']
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
# 获取当前时间戳 # 获取当前时间戳
now = int(time.time()) now = int(time.time())
# 获取当前分钟的起始时间戳 # 获取当前窗口的起始时间戳
now = now - now % 60 now = now - now % window_size
# 获取当前分钟的访问次数 # 获取当前窗口的访问次数
count = container.records.get(now, 0) count = container.records.get(now, 0)
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]
# 如果访问次数超过了限制 # 如果访问次数超过了限制
if count >= limitation: if count >= limitation:
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop': if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
return False return False
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait': elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
# 等待下一分钟 # 等待下一窗口
await asyncio.sleep(60 - time.time() % 60) await asyncio.sleep(window_size - time.time() % window_size)
now = int(time.time()) now = int(time.time())
now = now - now % 60 now = now - now % window_size
if now not in container.records: if now not in container.records:
container.records = {} container.records = {}

View File

@@ -11,7 +11,10 @@ from ...core import entities as core_entities
@stage.stage_class("RequireRateLimitOccupancy") @stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy") @stage.stage_class("ReleaseRateLimitOccupancy")
class RateLimit(stage.PipelineStage): class RateLimit(stage.PipelineStage):
"""限速器控制阶段""" """限速器控制阶段
不改写query只检查是否需要限速。
"""
algo: algo.ReteLimitAlgo algo: algo.ReteLimitAlgo

View File

@@ -31,7 +31,7 @@ class SendResponseBackStage(stage.PipelineStage):
await self.ap.platform_mgr.send( await self.ap.platform_mgr.send(
query.message_event, query.message_event,
query.resp_message_chain, query.resp_message_chain[-1],
adapter=query.adapter adapter=query.adapter
) )

View File

@@ -14,9 +14,12 @@ from ...config import manager as cfg_mgr
@stage.stage_class("GroupRespondRuleCheckStage") @stage.stage_class("GroupRespondRuleCheckStage")
class GroupRespondRuleCheckStage(stage.PipelineStage): class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器 """群组响应规则检查器
仅检查群消息是否符合规则。
""" """
rule_matchers: list[rule.GroupRespondRule] rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self): async def initialize(self):
"""初始化检查器 """初始化检查器
@@ -31,7 +34,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query

View File

@@ -13,21 +13,23 @@ from .respback import respback
from .wrapper import wrapper from .wrapper import wrapper
from .preproc import preproc from .preproc import preproc
from .ratelimit import ratelimit from .ratelimit import ratelimit
from .msgtrun import msgtrun
# 请求处理阶段顺序 # 请求处理阶段顺序
stage_order = [ stage_order = [
"GroupRespondRuleCheckStage", "GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", "BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", "PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", "PreProcessor", # 预处理器
"RequireRateLimitOccupancy", "ConversationMessageTruncator", # 会话消息截断器
"MessageProcessor", "RequireRateLimitOccupancy", # 请求速率限制占用
"ReleaseRateLimitOccupancy", "MessageProcessor", # 处理器
"PostContentFilterStage", "ReleaseRateLimitOccupancy", # 释放速率限制占用
"ResponseWrapper", "PostContentFilterStage", # 内容过滤后置阶段
"LongTextProcessStage", "ResponseWrapper", # 响应包装器
"SendResponseBackStage", "LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
] ]

View File

@@ -14,6 +14,13 @@ from ...plugin import events
@stage.stage_class("ResponseWrapper") @stage.stage_class("ResponseWrapper")
class ResponseWrapper(stage.PipelineStage): class ResponseWrapper(stage.PipelineStage):
"""回复包装阶段
把回复的 message 包装成人类识读的形式。
改写:
- resp_message_chain
"""
async def initialize(self): async def initialize(self):
pass pass
@@ -25,78 +32,49 @@ class ResponseWrapper(stage.PipelineStage):
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理 """处理
""" """
if query.resp_messages[-1].role == 'command': # 如果 resp_messages[-1] 已经是 MessageChain 了
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content) if isinstance(query.resp_messages[-1], mirai.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
elif query.resp_messages[-1].role == 'plugin':
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
else:
query.resp_message_chain = query.resp_messages[-1].content
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else: else:
if query.resp_messages[-1].role == 'command':
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
if query.resp_messages[-1].role == 'assistant': yield entities.StageProcessResult(
result = query.resp_messages[-1] result_type=entities.ResultType.CONTINUE,
session = await self.ap.sess_mgr.get_session(query) new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
# else:
# query.resp_message_chain.append(query.resp_messages[-1].content)
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
reply_text = '' yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
if result.content is not None: # 有内容 if query.resp_messages[-1].role == 'assistant':
reply_text = result.content result = query.resp_messages[-1]
session = await self.ap.sess_mgr.get_session(query)
# ============= 触发插件事件 =============== reply_text = ''
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else: if result.content: # 有内容
reply_text = str(result.get_content_mirai_message_chain())
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) # ============= 触发插件事件 ===============
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.platform_cfg.data['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded( event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value, launcher_type=query.launcher_type.value,
@@ -110,7 +88,6 @@ class ResponseWrapper(stage.PipelineStage):
query=query query=query
) )
) )
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, result_type=entities.ResultType.INTERRUPT,
@@ -119,13 +96,56 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else: else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) query.resp_message_chain.append(result.get_content_mirai_message_chain())
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -31,13 +31,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
msg_time = msg.time msg_time = msg.time
elif type(msg) is mirai.Image: elif type(msg) is mirai.Image:
arg = '' arg = ''
if msg.base64:
if msg.url: arg = msg.base64
msg_list.append(aiocqhttp.MessageSegment.image(f"base64://{arg}"))
elif msg.url:
arg = msg.url arg = msg.url
msg_list.append(aiocqhttp.MessageSegment.image(arg))
elif msg.path: elif msg.path:
arg = msg.path arg = msg.path
msg_list.append(aiocqhttp.MessageSegment.image(arg))
msg_list.append(aiocqhttp.MessageSegment.image(arg))
elif type(msg) is mirai.At: elif type(msg) is mirai.At:
msg_list.append(aiocqhttp.MessageSegment.at(msg.target)) msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
elif type(msg) is mirai.AtAll: elif type(msg) is mirai.AtAll:

View File

@@ -322,7 +322,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
proxies=None proxies=None
) )
if resp.status_code == 403: if resp.status_code == 403:
raise Exception("go-cqhttp拒绝访问请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配") raise Exception("go-cqhttp拒绝访问请检查配置文件中nakuru适配器的配置")
self.bot_account_id = int(resp.json()['data']['user_id']) self.bot_account_id = int(resp.json()['data']['user_id'])
except Exception as e: except Exception as e:
raise Exception("获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确") raise Exception("获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确")

View File

@@ -21,6 +21,39 @@ class ToolCall(pydantic.BaseModel):
function: FunctionCall function: FunctionCall
class ImageURLContentObject(pydantic.BaseModel):
url: str
def __str__(self):
return self.url[:128] + ('...' if len(self.url) > 128 else '')
class ContentElement(pydantic.BaseModel):
type: str
"""内容类型"""
text: typing.Optional[str] = None
image_url: typing.Optional[ImageURLContentObject] = None
def __str__(self):
if self.type == 'text':
return self.text
elif self.type == 'image_url':
return f'[图片]({self.image_url})'
else:
return '未知内容'
@classmethod
def from_text(cls, text: str):
return cls(type='text', text=text)
@classmethod
def from_image_url(cls, image_url: str):
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
class Message(pydantic.BaseModel): class Message(pydantic.BaseModel):
"""消息""" """消息"""
@@ -30,12 +63,9 @@ class Message(pydantic.BaseModel):
name: typing.Optional[str] = None name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置""" """名称,仅函数调用返回时设置"""
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
"""内容""" """内容"""
function_call: typing.Optional[FunctionCall] = None
"""函数调用不再受支持请使用tool_calls"""
tool_calls: typing.Optional[list[ToolCall]] = None tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用""" """工具调用"""
@@ -43,10 +73,47 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str: def readable_str(self) -> str:
if self.content is not None: if self.content is not None:
return str(self.content) return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
elif self.function_call is not None:
return f'{self.function_call.name}({self.function_call.arguments})'
elif self.tool_calls is not None: elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}' return f'调用工具: {self.tool_calls[0].id}'
else: else:
return '未知消息' return '未知消息'
def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
"""将内容转换为 Mirai MessageChain 对象
Args:
prefix_text (str): 首个文字组件的前缀文本
"""
if self.content is None:
return None
elif isinstance(self.content, str):
return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
elif isinstance(self.content, list):
mc = []
for ce in self.content:
if ce.type == 'text':
mc.append(mirai.Plain(ce.text))
elif ce.type == 'image':
if ce.image_url.url.startswith("http"):
mc.append(mirai.Image(url=ce.image_url.url))
else: # base64
b64_str = ce.image_url.url
if b64_str.startswith("data:"):
b64_str = b64_str.split(",")[1]
mc.append(mirai.Image(base64=b64_str))
# 找第一个文字组件
if prefix_text:
for i, c in enumerate(mc):
if isinstance(c, mirai.Plain):
mc[i] = mirai.Plain(prefix_text+c.text)
break
else:
mc.insert(0, mirai.Plain(prefix_text))
return mirai.MessageChain(mc)

View File

@@ -6,6 +6,8 @@ import typing
from ...core import app from ...core import app
from ...core import entities as core_entities from ...core import entities as core_entities
from .. import entities as llm_entities from .. import entities as llm_entities
from . import entities as modelmgr_entities
from ..tools import entities as tools_entities
preregistered_requesters: list[typing.Type[LLMAPIRequester]] = [] preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
@@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def initialize(self): async def initialize(self):
pass pass
@abc.abstractmethod async def preprocess(
async def request(
self, self,
query: core_entities.Query, query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]: ):
"""请求API """预处理
在这里处理特定API对Query对象的兼容性问题。
"""
pass
对话前文可以从 query 对象中获取。 @abc.abstractmethod
可以多次yield消息对象。 async def call(
self,
model: modelmgr_entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
"""调用API
Args: Args:
query (core_entities.Query): 本次请求的上下文对象 model (modelmgr_entities.LLMModelInfo): 使用的模型信息
messages (typing.List[llm_entities.Message]): 消息对象列表
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
Yields: Returns:
pkg.provider.entities.Message: 返回消息对象 llm_entities.Message: 返回消息对象
""" """
raise NotImplementedError pass

View File

@@ -11,6 +11,7 @@ from .. import api, entities, errors
from ....core import entities as core_entities from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
from ...tools import entities as tools_entities from ...tools import entities as tools_entities
from ....utils import image
@api.requester_class("anthropic-messages") @api.requester_class("anthropic-messages")
@@ -27,47 +28,76 @@ class AnthropicMessages(api.LLMAPIRequester):
proxies=self.ap.proxy_mgr.get_forward_proxies() proxies=self.ap.proxy_mgr.get_forward_proxies()
) )
async def request( async def call(
self, self,
query: core_entities.Query, model: entities.LLMModelInfo,
) -> typing.AsyncGenerator[llm_entities.Message, None]: messages: typing.List[llm_entities.Message],
self.client.api_key = query.use_model.token_mgr.get_token() funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy() args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name args["model"] = model.name if model.model_name is None else model.model_name
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 # 处理消息
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
] + [m.dict(exclude_none=True) for m in query.messages]
# 删除所有 role=system & content='' 的消息 # system
req_messages = [ system_role_message = None
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]
# 检查是否有 role=system 的消息,若有,改为 role=user并在后面加一个 role=assistant 的消息 for i, m in enumerate(messages):
system_role_index = [] if m.role == "system":
for i, m in enumerate(req_messages): system_role_message = m
if m["role"] == "system":
system_role_index.append(i)
m["role"] = "user"
if system_role_index: messages.pop(i)
for i in system_role_index[::-1]: break
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
# 忽略掉空消息,用户可能发送空消息,而上层未过滤 if isinstance(system_role_message, llm_entities.Message) \
req_messages = [ and isinstance(system_role_message.content, str):
m for m in req_messages if m["content"].strip() != "" args['system'] = system_role_message.content
]
req_messages = []
for m in messages:
if isinstance(m.content, str) and m.content.strip() != "":
req_messages.append(m.dict(exclude_none=True))
elif isinstance(m.content, list):
# m.content = [
# c for c in m.content if c.type == "text"
# ]
# if len(m.content) > 0:
# req_messages.append(m.dict(exclude_none=True))
msg_dict = m.dict(exclude_none=True)
for i, ce in enumerate(m.content):
if ce.type == "image_url":
alter_image_ele = {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": await image.qq_image_url_to_base64(ce.image_url.url)
}
}
msg_dict["content"][i] = alter_image_ele
req_messages.append(msg_dict)
args["messages"] = req_messages args["messages"] = req_messages
try: # anthropic的tools处在beta阶段sdk不稳定故暂时不支持
#
# if funcs:
# tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
# if tools:
# args["tools"] = tools
try:
resp = await self.client.messages.create(**args) resp = await self.client.messages.create(**args)
yield llm_entities.Message( return llm_entities.Message(
content=resp.content[0].text, content=resp.content[0].text,
role=resp.role role=resp.role
) )
@@ -79,4 +109,4 @@ class AnthropicMessages(api.LLMAPIRequester):
if 'model: ' in str(e): if 'model: ' in str(e):
raise errors.RequesterError(f'模型无效: {e.message}') raise errors.RequesterError(f'模型无效: {e.message}')
else: else:
raise errors.RequesterError(f'请求地址无效: {e.message}') raise errors.RequesterError(f'请求地址无效: {e.message}')

View File

@@ -3,16 +3,20 @@ from __future__ import annotations
import asyncio import asyncio
import typing import typing
import json import json
import base64
from typing import AsyncGenerator from typing import AsyncGenerator
import openai import openai
import openai.types.chat.chat_completion as chat_completion import openai.types.chat.chat_completion as chat_completion
import httpx import httpx
import aiohttp
import async_lru
from .. import api, entities, errors from .. import api, entities, errors
from ....core import entities as core_entities, app from ....core import entities as core_entities, app
from ... import entities as llm_entities from ... import entities as llm_entities
from ...tools import entities as tools_entities from ...tools import entities as tools_entities
from ....utils import image
@api.requester_class("openai-chat-completions") @api.requester_class("openai-chat-completions")
@@ -43,7 +47,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
self, self,
args: dict, args: dict,
) -> chat_completion.ChatCompletion: ) -> chat_completion.ChatCompletion:
self.ap.logger.debug(f"req chat_completion with args {args}")
return await self.client.chat.completions.create(**args) return await self.client.chat.completions.create(**args)
async def _make_msg( async def _make_msg(
@@ -67,14 +70,22 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
args = self.requester_cfg['args'].copy() args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_model.tool_call_supported: if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools: if tools:
args["tools"] = tools args["tools"] = tools
# 设置此次请求中的messages # 设置此次请求中的messages
messages = req_messages messages = req_messages.copy()
# 检查vision
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_url":
me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])
args["messages"] = messages args["messages"] = messages
# 发送请求 # 发送请求
@@ -84,73 +95,19 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
message = await self._make_msg(resp) message = await self._make_msg(resp)
return message return message
async def _request( async def call(
self, query: core_entities.Query self,
) -> typing.AsyncGenerator[llm_entities.Message, None]: model: entities.LLMModelInfo,
"""请求""" messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
pending_tool_calls = [] ) -> llm_entities.Message:
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != "" m.dict(exclude_none=True) for m in messages
] + [m.dict(exclude_none=True) for m in query.messages] ]
# req_messages.append({"role": "user", "content": str(query.message_chain)})
# 首次请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg.dict(exclude_none=True))
except Exception as e:
# 出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(
err_msg.dict(exclude_none=True)
)
# 处理完所有调用,继续请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
try: try:
async for msg in self._request(query): return await self._closure(req_messages, model, funcs)
yield msg
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise errors.RequesterError('请求超时') raise errors.RequesterError('请求超时')
except openai.BadRequestError as e: except openai.BadRequestError as e:
@@ -163,6 +120,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
except openai.NotFoundError as e: except openai.NotFoundError as e:
raise errors.RequesterError(f'请求路径错误: {e.message}') raise errors.RequesterError(f'请求路径错误: {e.message}')
except openai.RateLimitError as e: except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁: {e.message}') raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e: except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}') raise errors.RequesterError(f'请求错误: {e.message}')
@async_lru.alru_cache(maxsize=128)
async def get_base64_str(
self,
original_url: str,
) -> str:
base64_image = await image.qq_image_url_to_base64(original_url)
return f"data:image/jpeg;base64,{base64_image}"

View File

@@ -3,7 +3,10 @@ from __future__ import annotations
from ....core import app from ....core import app
from . import chatcmpl from . import chatcmpl
from .. import api from .. import api, entities, errors
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("deepseek-chat-completions") @api.requester_class("deepseek-chat-completions")
@@ -12,4 +15,39 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions'] self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
self.ap = ap self.ap = ap
async def _closure(
self,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
# 设置此次请求中的messages
messages = req_messages
# deepseek 不支持多模态把content都转换成纯文字
for m in messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])
args["messages"] = messages
# 发送请求
resp = await self._req(args)
# 处理请求结果
message = await self._make_msg(resp)
return message

View File

@@ -3,7 +3,10 @@ from __future__ import annotations
from ....core import app from ....core import app
from . import chatcmpl from . import chatcmpl
from .. import api from .. import api, entities, errors
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("moonshot-chat-completions") @api.requester_class("moonshot-chat-completions")
@@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions'] self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
self.ap = ap self.ap = ap
async def _closure(
self,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
# 设置此次请求中的messages
messages = req_messages
# deepseek 不支持多模态把content都转换成纯文字
for m in messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])
# 删除空的
messages = [m for m in messages if m["content"].strip() != ""]
args["messages"] = messages
# 发送请求
resp = await self._req(args)
# 处理请求结果
message = await self._make_msg(resp)
return message

View File

@@ -21,5 +21,7 @@ class LLMModelInfo(pydantic.BaseModel):
tool_call_supported: typing.Optional[bool] = False tool_call_supported: typing.Optional[bool] = False
vision_supported: typing.Optional[bool] = False
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True

View File

@@ -37,7 +37,7 @@ class ModelManager:
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置") raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
async def initialize(self): async def initialize(self):
# 初始化token_mgr, requester # 初始化token_mgr, requester
for k, v in self.ap.provider_cfg.data['keys'].items(): for k, v in self.ap.provider_cfg.data['keys'].items():
self.token_mgrs[k] = token.TokenManager(k, v) self.token_mgrs[k] = token.TokenManager(k, v)
@@ -83,7 +83,8 @@ class ModelManager:
model_name=None, model_name=None,
token_mgr=self.token_mgrs[model['token_mgr']], token_mgr=self.token_mgrs[model['token_mgr']],
requester=self.requesters[model['requester']], requester=self.requesters[model['requester']],
tool_call_supported=model['tool_call_supported'] tool_call_supported=model['tool_call_supported'],
vision_supported=model['vision_supported']
) )
break break
@@ -95,13 +96,15 @@ class ModelManager:
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported) tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
vision_supported = model.get('vision_supported', default_model_info.vision_supported)
model_info = entities.LLMModelInfo( model_info = entities.LLMModelInfo(
name=model['name'], name=model['name'],
model_name=model_name, model_name=model_name,
token_mgr=token_mgr, token_mgr=token_mgr,
requester=requester, requester=requester,
tool_call_supported=tool_call_supported tool_call_supported=tool_call_supported,
vision_supported=vision_supported
) )
self.model_list.append(model_info) self.model_list.append(model_info)

View File

@@ -9,11 +9,10 @@ from ...plugin import context as plugin_context
class ToolManager: class ToolManager:
"""LLM工具管理器 """LLM工具管理器"""
"""
ap: app.Application ap: app.Application
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.ap = ap self.ap = ap
self.all_functions = [] self.all_functions = []
@@ -22,35 +21,33 @@ class ToolManager:
pass pass
async def get_function(self, name: str) -> entities.LLMFunction: async def get_function(self, name: str) -> entities.LLMFunction:
"""获取函数 """获取函数"""
"""
for function in await self.get_all_functions(): for function in await self.get_all_functions():
if function.name == name: if function.name == name:
return function return function
return None return None
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: async def get_function_and_plugin(
"""获取函数和插件 self, name: str
""" ) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
"""获取函数和插件"""
for plugin in self.ap.plugin_mgr.plugins: for plugin in self.ap.plugin_mgr.plugins:
for function in plugin.content_functions: for function in plugin.content_functions:
if function.name == name: if function.name == name:
return function, plugin.plugin_inst return function, plugin.plugin_inst
return None, None return None, None
async def get_all_functions(self) -> list[entities.LLMFunction]: async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数 """获取所有函数"""
"""
all_functions: list[entities.LLMFunction] = [] all_functions: list[entities.LLMFunction] = []
for plugin in self.ap.plugin_mgr.plugins: for plugin in self.ap.plugin_mgr.plugins:
all_functions.extend(plugin.content_functions) all_functions.extend(plugin.content_functions)
return all_functions return all_functions
async def generate_tools_for_openai(self, use_funcs: entities.LLMFunction) -> str: async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
"""生成函数列表 """生成函数列表"""
"""
tools = [] tools = []
for function in use_funcs: for function in use_funcs:
@@ -60,40 +57,71 @@ class ToolManager:
"function": { "function": {
"name": function.name, "name": function.name,
"description": function.description, "description": function.description,
"parameters": function.parameters "parameters": function.parameters,
} },
}
tools.append(function_schema)
return tools
async def generate_tools_for_anthropic(
self, use_funcs: list[entities.LLMFunction]
) -> list:
"""为anthropic生成函数列表
e.g.
[
{
"name": "get_stock_price",
"description": "Get the current stock price for a given ticker symbol.",
"input_schema": {
"type": "object",
"properties": {
"ticker": {
"type": "string",
"description": "The stock ticker symbol, e.g. AAPL for Apple Inc."
}
},
"required": ["ticker"]
}
}
]
"""
tools = []
for function in use_funcs:
if function.enable:
function_schema = {
"name": function.name,
"description": function.description,
"input_schema": function.parameters,
} }
tools.append(function_schema) tools.append(function_schema)
return tools return tools
async def execute_func_call( async def execute_func_call(
self, self, query: core_entities.Query, name: str, parameters: dict
query: core_entities.Query,
name: str,
parameters: dict
) -> typing.Any: ) -> typing.Any:
"""执行函数调用 """执行函数调用"""
"""
try: try:
function, plugin = await self.get_function_and_plugin(name) function, plugin = await self.get_function_and_plugin(name)
if function is None: if function is None:
return None return None
parameters = parameters.copy() parameters = parameters.copy()
parameters = { parameters = {"query": query, **parameters}
"query": query,
**parameters
}
return await function.func(plugin, **parameters) return await function.func(plugin, **parameters)
except Exception as e: except Exception as e:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}")
traceback.print_exc() traceback.print_exc()
return f'error occurred when executing function {name}: {e}' return f"error occurred when executing function {name}: {e}"
finally: finally:
plugin = None plugin = None
@@ -107,11 +135,11 @@ class ToolManager:
await self.ap.ctr_mgr.usage.post_function_record( await self.ap.ctr_mgr.usage.post_function_record(
plugin={ plugin={
'name': plugin.plugin_name, "name": plugin.plugin_name,
'remote': plugin.plugin_source, "remote": plugin.plugin_source,
'version': plugin.plugin_version, "version": plugin.plugin_version,
'author': plugin.plugin_author "author": plugin.plugin_author,
}, },
function_name=function.name, function_name=function.name,
function_description=function.description, function_description=function.description,
) )

View File

@@ -1 +1 @@
semantic_version = "v3.1.1" semantic_version = "v3.2.2"

41
pkg/utils/image.py Normal file
View File

@@ -0,0 +1,41 @@
import base64
import typing
from urllib.parse import urlparse, parse_qs
import ssl
import aiohttp
async def qq_image_url_to_base64(
image_url: str
) -> str:
"""将QQ图片URL转为base64
Args:
image_url (str): QQ图片URL
Returns:
str: base64编码
"""
parsed = urlparse(image_url)
query = parse_qs(parsed.query)
# Flatten the query dictionary
query = {k: v[0] for k, v in query.items()}
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with aiohttp.ClientSession(trust_env=False) as session:
async with session.get(
f"http://{parsed.netloc}{parsed.path}",
params=query,
ssl=ssl_context
) as resp:
resp.raise_for_status() # 检查HTTP错误
file_bytes = await resp.read()
base64_str = base64.b64encode(file_bytes).decode()
return base64_str

View File

@@ -13,4 +13,5 @@ aiohttp
pydantic pydantic
websockets websockets
urllib3 urllib3
psutil psutil
async-lru

View File

@@ -4,23 +4,73 @@
"name": "default", "name": "default",
"requester": "openai-chat-completions", "requester": "openai-chat-completions",
"token_mgr": "openai", "token_mgr": "openai",
"tool_call_supported": false "tool_call_supported": false,
"vision_supported": false
},
{
"name": "gpt-3.5-turbo-0125",
"tool_call_supported": true,
"vision_supported": false
}, },
{ {
"name": "gpt-3.5-turbo", "name": "gpt-3.5-turbo",
"tool_call_supported": true "tool_call_supported": true,
"vision_supported": false
}, },
{ {
"name": "gpt-4", "name": "gpt-3.5-turbo-1106",
"tool_call_supported": true "tool_call_supported": true,
"vision_supported": false
},
{
"name": "gpt-4-turbo",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-turbo-2024-04-09",
"tool_call_supported": true,
"vision_supported": true
}, },
{ {
"name": "gpt-4-turbo-preview", "name": "gpt-4-turbo-preview",
"tool_call_supported": true "tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-0125-preview",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-1106-preview",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4o",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-0613",
"tool_call_supported": true,
"vision_supported": true
}, },
{ {
"name": "gpt-4-32k", "name": "gpt-4-32k",
"tool_call_supported": true "tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-32k-0613",
"tool_call_supported": true,
"vision_supported": true
}, },
{ {
"model_name": "SparkDesk", "model_name": "SparkDesk",
@@ -33,32 +83,38 @@
{ {
"name": "claude-3-opus-20240229", "name": "claude-3-opus-20240229",
"requester": "anthropic-messages", "requester": "anthropic-messages",
"token_mgr": "anthropic" "token_mgr": "anthropic",
"vision_supported": true
}, },
{ {
"name": "claude-3-sonnet-20240229", "name": "claude-3-sonnet-20240229",
"requester": "anthropic-messages", "requester": "anthropic-messages",
"token_mgr": "anthropic" "token_mgr": "anthropic",
"vision_supported": true
}, },
{ {
"name": "claude-3-haiku-20240307", "name": "claude-3-haiku-20240307",
"requester": "anthropic-messages", "requester": "anthropic-messages",
"token_mgr": "anthropic" "token_mgr": "anthropic",
"vision_supported": true
}, },
{ {
"name": "moonshot-v1-8k", "name": "moonshot-v1-8k",
"requester": "moonshot-chat-completions", "requester": "moonshot-chat-completions",
"token_mgr": "moonshot" "token_mgr": "moonshot",
"tool_call_supported": true
}, },
{ {
"name": "moonshot-v1-32k", "name": "moonshot-v1-32k",
"requester": "moonshot-chat-completions", "requester": "moonshot-chat-completions",
"token_mgr": "moonshot" "token_mgr": "moonshot",
"tool_call_supported": true
}, },
{ {
"name": "moonshot-v1-128k", "name": "moonshot-v1-128k",
"requester": "moonshot-chat-completions", "requester": "moonshot-chat-completions",
"token_mgr": "moonshot" "token_mgr": "moonshot",
"tool_call_supported": true
}, },
{ {
"name": "deepseek-chat", "name": "deepseek-chat",

View File

@@ -29,7 +29,16 @@
"strategy": "drop", "strategy": "drop",
"algo": "fixwin", "algo": "fixwin",
"fixwin": { "fixwin": {
"default": 60 "default": {
"window-size": 60,
"limit": 60
}
}
},
"msg-truncate": {
"method": "round",
"round": {
"max-round": 10
} }
} }
} }

View File

@@ -1,5 +1,6 @@
{ {
"enable-chat": true, "enable-chat": true,
"enable-vision": true,
"keys": { "keys": {
"openai": [ "openai": [
"sk-1234567890" "sk-1234567890"

View File

@@ -10,5 +10,6 @@
"default": 1 "default": 1
}, },
"pipeline-concurrency": 20, "pipeline-concurrency": 20,
"qcg-center-url": "https://api.qchatgpt.rockchin.top/api/v2",
"help-message": "QChatGPT - 😎高稳定性、🧩支持插件、🌏实时联网的 ChatGPT QQ 机器人🤖\n链接https://q.rkcn.top" "help-message": "QChatGPT - 😎高稳定性、🧩支持插件、🌏实时联网的 ChatGPT QQ 机器人🤖\n链接https://q.rkcn.top"
} }