mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-07 22:36:02 +00:00
feat: 模型视觉多模态支持
This commit is contained in:
@@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel):
|
||||
"""命令返回值
|
||||
"""
|
||||
|
||||
text: typing.Optional[str]
|
||||
text: typing.Optional[str] = None
|
||||
"""文本
|
||||
"""
|
||||
|
||||
image: typing.Optional[mirai.Image]
|
||||
image: typing.Optional[mirai.Image] = None
|
||||
"""弃用"""
|
||||
|
||||
image_url: typing.Optional[str] = None
|
||||
"""图片链接
|
||||
"""
|
||||
|
||||
error: typing.Optional[errors.CommandError]= None
|
||||
"""错误
|
||||
|
||||
@@ -27,7 +27,7 @@ class JSONConfigFile(file_model.ConfigFile):
|
||||
else:
|
||||
raise ValueError("template_file_name or template_data must be provided")
|
||||
|
||||
async def load(self) -> dict:
|
||||
async def load(self, completion: bool=True) -> dict:
|
||||
|
||||
if not self.exists():
|
||||
await self.create()
|
||||
@@ -39,9 +39,11 @@ class JSONConfigFile(file_model.ConfigFile):
|
||||
with open(self.config_file_name, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
for key in self.template_data:
|
||||
if key not in cfg:
|
||||
cfg[key] = self.template_data[key]
|
||||
if completion:
|
||||
|
||||
for key in self.template_data:
|
||||
if key not in cfg:
|
||||
cfg[key] = self.template_data[key]
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
|
||||
async def create(self):
|
||||
shutil.copyfile(self.template_file_name, self.config_file_name)
|
||||
|
||||
async def load(self) -> dict:
|
||||
async def load(self, completion: bool=True) -> dict:
|
||||
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
@@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile):
|
||||
cfg[key] = getattr(module, key)
|
||||
|
||||
# 从模板模块文件中进行补全
|
||||
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
|
||||
module = importlib.import_module(module_name)
|
||||
if completion:
|
||||
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
for key in dir(module):
|
||||
if key.startswith('__'):
|
||||
continue
|
||||
for key in dir(module):
|
||||
if key.startswith('__'):
|
||||
continue
|
||||
|
||||
if not isinstance(getattr(module, key), allowed_types):
|
||||
continue
|
||||
if not isinstance(getattr(module, key), allowed_types):
|
||||
continue
|
||||
|
||||
if key not in cfg:
|
||||
cfg[key] = getattr(module, key)
|
||||
if key not in cfg:
|
||||
cfg[key] = getattr(module, key)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ class ConfigManager:
|
||||
self.file = cfg_file
|
||||
self.data = {}
|
||||
|
||||
async def load_config(self):
|
||||
self.data = await self.file.load()
|
||||
async def load_config(self, completion: bool=True):
|
||||
self.data = await self.file.load(completion=completion)
|
||||
|
||||
async def dump_config(self):
|
||||
await self.file.save(self.data)
|
||||
@@ -30,7 +30,7 @@ class ConfigManager:
|
||||
self.file.save_sync(self.data)
|
||||
|
||||
|
||||
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
|
||||
async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
|
||||
"""加载Python模块配置文件"""
|
||||
cfg_inst = pymodule.PythonModuleConfigFile(
|
||||
config_name,
|
||||
@@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
|
||||
)
|
||||
|
||||
cfg_mgr = ConfigManager(cfg_inst)
|
||||
await cfg_mgr.load_config()
|
||||
await cfg_mgr.load_config(completion=completion)
|
||||
|
||||
return cfg_mgr
|
||||
|
||||
|
||||
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
|
||||
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
|
||||
"""加载JSON配置文件"""
|
||||
cfg_inst = json_file.JSONConfigFile(
|
||||
config_name,
|
||||
@@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
|
||||
)
|
||||
|
||||
cfg_mgr = ConfigManager(cfg_inst)
|
||||
await cfg_mgr.load_config()
|
||||
await cfg_mgr.load_config(completion=completion)
|
||||
|
||||
return cfg_mgr
|
||||
35
pkg/config/migrations/m006_vision_and_oss_config.py
Normal file
35
pkg/config/migrations/m006_vision_and_oss_config.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("vision-and-oss-config", 6)
|
||||
class VisionAndOSSConfigMigration(migration.Migration):
|
||||
"""迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return "enable-vision" not in self.ap.provider_cfg.data \
|
||||
or "oss" not in self.ap.system_cfg.data
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
if "enable-vision" not in self.ap.provider_cfg.data:
|
||||
self.ap.provider_cfg.data["enable-vision"] = False
|
||||
|
||||
if "oss" not in self.ap.system_cfg.data:
|
||||
self.ap.system_cfg.data["oss"] = [
|
||||
{
|
||||
"type": "aliyun",
|
||||
"endpoint": "https://oss-cn-hangzhou.aliyuncs.com",
|
||||
"public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com",
|
||||
"access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5",
|
||||
"access-key-secret": "xxxxxx",
|
||||
"bucket": "qchatgpt",
|
||||
"prefix": "qchatgpt",
|
||||
"enable": False,
|
||||
}
|
||||
]
|
||||
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
await self.ap.system_cfg.dump_config()
|
||||
@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self) -> dict:
|
||||
async def load(self, completion: bool=True) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -15,6 +15,7 @@ from ..command import cmdmgr
|
||||
from ..plugin import manager as plugin_mgr
|
||||
from ..pipeline import pool
|
||||
from ..pipeline import controller, stagemgr
|
||||
from ..oss import oss
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr
|
||||
|
||||
|
||||
@@ -71,6 +72,8 @@ class Application:
|
||||
|
||||
proxy_mgr: proxy_mgr.ProxyManager = None
|
||||
|
||||
oss_mgr: oss.OSSServiceManager = None
|
||||
|
||||
logger: logging.Logger = None
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -14,6 +14,7 @@ required_deps = {
|
||||
"yaml": "pyyaml",
|
||||
"aiohttp": "aiohttp",
|
||||
"psutil": "psutil",
|
||||
"oss2": "oss2",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...platform import manager as im_mgr
|
||||
from ...oss import oss as oss_mgr
|
||||
|
||||
|
||||
@stage.stage_class("BuildAppStage")
|
||||
@@ -68,6 +69,10 @@ class BuildAppStage(stage.BootingStage):
|
||||
await cmd_mgr_inst.initialize()
|
||||
ap.cmd_mgr = cmd_mgr_inst
|
||||
|
||||
oss_mgr_inst = oss_mgr.OSSServiceManager(ap)
|
||||
await oss_mgr_inst.initialize()
|
||||
ap.oss_mgr = oss_mgr_inst
|
||||
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
ap.model_mgr = llm_model_mgr_inst
|
||||
@@ -83,7 +88,6 @@ class BuildAppStage(stage.BootingStage):
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.platform_mgr = im_mgr_inst
|
||||
@@ -92,5 +96,6 @@ class BuildAppStage(stage.BootingStage):
|
||||
await stage_mgr.initialize()
|
||||
ap.stage_mgr = stage_mgr
|
||||
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
|
||||
@@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage):
|
||||
async def run(self, ap: app.Application):
|
||||
"""启动
|
||||
"""
|
||||
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
|
||||
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
|
||||
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
|
||||
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
|
||||
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
|
||||
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
|
||||
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
|
||||
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
|
||||
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
|
||||
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)
|
||||
|
||||
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
|
||||
await ap.plugin_setting_meta.dump_config()
|
||||
|
||||
@@ -5,7 +5,7 @@ import importlib
|
||||
from .. import stage, app
|
||||
from ...config import migration
|
||||
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
||||
from ...config.migrations import m005_deepseek_cfg_completion
|
||||
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_and_oss_config
|
||||
|
||||
|
||||
@stage.stage_class("MigrationStage")
|
||||
|
||||
0
pkg/oss/__init__.py
Normal file
0
pkg/oss/__init__.py
Normal file
85
pkg/oss/oss.py
Normal file
85
pkg/oss/oss.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
import typing
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
import ssl
|
||||
|
||||
from . import service as osssv
|
||||
from ..core import app
|
||||
from .services import aliyun
|
||||
|
||||
|
||||
class OSSServiceManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
service: osssv.OSSService = None
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化
|
||||
"""
|
||||
|
||||
mapping = {}
|
||||
|
||||
for svcls in osssv.preregistered_services:
|
||||
mapping[svcls.name] = svcls
|
||||
|
||||
for sv in self.ap.system_cfg.data['oss']:
|
||||
if sv['enable']:
|
||||
|
||||
if sv['type'] not in mapping:
|
||||
raise Exception(f"未知的OSS服务类型: {sv['type']}")
|
||||
|
||||
self.service = mapping[sv['type']](self.ap, sv)
|
||||
await self.service.initialize()
|
||||
break
|
||||
|
||||
def available(self) -> bool:
|
||||
"""是否可用
|
||||
|
||||
Returns:
|
||||
bool: 是否可用
|
||||
"""
|
||||
return self.service is not None
|
||||
|
||||
async def fetch_image(self, image_url: str) -> bytes:
|
||||
parsed = urlparse(image_url)
|
||||
query = parse_qs(parsed.query)
|
||||
|
||||
# Flatten the query dictionary
|
||||
query = {k: v[0] for k, v in query.items()}
|
||||
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=False) as session:
|
||||
async with session.get(
|
||||
f"http://{parsed.netloc}{parsed.path}",
|
||||
params=query,
|
||||
ssl=ssl_context
|
||||
) as resp:
|
||||
resp.raise_for_status() # 检查HTTP错误
|
||||
file_bytes = await resp.read()
|
||||
return file_bytes
|
||||
|
||||
async def upload_url_image(
|
||||
self,
|
||||
image_url: str,
|
||||
) -> str:
|
||||
"""上传URL图片
|
||||
|
||||
Args:
|
||||
image_url (str): 图片URL
|
||||
|
||||
Returns:
|
||||
str: 文件URL
|
||||
"""
|
||||
|
||||
file_bytes = await self.fetch_image(image_url)
|
||||
|
||||
return await self.service.upload(file_bytes=file_bytes, ext=".jpg")
|
||||
67
pkg/oss/service.py
Normal file
67
pkg/oss/service.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ..core import app
|
||||
|
||||
|
||||
preregistered_services: list[typing.Type[OSSService]] = []
|
||||
|
||||
def service_class(
|
||||
name: str
|
||||
) -> typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]:
|
||||
"""OSS服务类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 服务名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]: 装饰器
|
||||
"""
|
||||
def decorator(cls: typing.Type[OSSService]) -> typing.Type[OSSService]:
|
||||
assert issubclass(cls, OSSService)
|
||||
|
||||
cls.name = name
|
||||
|
||||
preregistered_services.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class OSSService(metaclass=abc.ABCMeta):
|
||||
"""OSS抽象类"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
cfg: dict
|
||||
|
||||
def __init__(self, ap: app.Application, cfg: dict) -> None:
|
||||
self.ap = ap
|
||||
self.cfg = cfg
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upload(
|
||||
self,
|
||||
local_file: str=None,
|
||||
file_bytes: bytes=None,
|
||||
ext: str=None,
|
||||
) -> str:
|
||||
"""上传文件
|
||||
|
||||
Args:
|
||||
local_file (str, optional): 本地文件路径. Defaults to None.
|
||||
file_bytes (bytes, optional): 文件字节. Defaults to None.
|
||||
ext (str, optional): 文件扩展名. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: 文件URL
|
||||
"""
|
||||
pass
|
||||
0
pkg/oss/services/__init__.py
Normal file
0
pkg/oss/services/__init__.py
Normal file
48
pkg/oss/services/aliyun.py
Normal file
48
pkg/oss/services/aliyun.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import oss2
|
||||
|
||||
from .. import service as osssv
|
||||
|
||||
|
||||
@osssv.service_class('aliyun')
|
||||
class AliyunOSSService(osssv.OSSService):
|
||||
"""阿里云OSS服务"""
|
||||
|
||||
auth: oss2.Auth
|
||||
|
||||
bucket: oss2.Bucket
|
||||
|
||||
async def initialize(self):
|
||||
self.auth = oss2.Auth(
|
||||
self.cfg['access-key-id'],
|
||||
self.cfg['access-key-secret']
|
||||
)
|
||||
|
||||
self.bucket = oss2.Bucket(
|
||||
self.auth,
|
||||
self.cfg['endpoint'],
|
||||
self.cfg['bucket']
|
||||
)
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
local_file: str=None,
|
||||
file_bytes: bytes=None,
|
||||
ext: str=None,
|
||||
) -> str:
|
||||
if local_file is not None:
|
||||
with open(local_file, 'rb') as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
if file_bytes is None:
|
||||
raise Exception("缺少文件内容")
|
||||
|
||||
name = str(uuid.uuid1())
|
||||
|
||||
key = f"{self.cfg['prefix']}/{name}{ext}"
|
||||
self.bucket.put_object(key, file_bytes)
|
||||
|
||||
return f"{self.cfg['public-read-base-url']}/{key}"
|
||||
@@ -9,6 +9,7 @@ from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
@stage.stage_class('PostContentFilterStage')
|
||||
@@ -141,6 +142,21 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""处理
|
||||
"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
|
||||
contain_non_text = False
|
||||
|
||||
for me in query.message_chain:
|
||||
if not isinstance(me, mirai.Plain):
|
||||
contain_non_text = True
|
||||
break
|
||||
|
||||
if contain_non_text:
|
||||
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
return await self._pre_process(
|
||||
str(query.message_chain).strip(),
|
||||
query
|
||||
|
||||
@@ -4,6 +4,8 @@ import enum
|
||||
|
||||
import pydantic
|
||||
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
class ResultLevel(enum.Enum):
|
||||
"""结果等级"""
|
||||
@@ -29,6 +31,13 @@ class EnableStage(enum.Enum):
|
||||
"""后处理"""
|
||||
|
||||
|
||||
class AcceptContent(enum.Enum):
|
||||
"""过滤器接受的内容模态"""
|
||||
|
||||
TEXT = enum.auto()
|
||||
|
||||
IMAGE_URL = enum.auto()
|
||||
|
||||
class FilterResult(pydantic.BaseModel):
|
||||
level: ResultLevel
|
||||
"""结果等级
|
||||
@@ -38,7 +47,7 @@ class FilterResult(pydantic.BaseModel):
|
||||
"""
|
||||
|
||||
replacement: str
|
||||
"""替换后的消息
|
||||
"""替换后的文本消息
|
||||
|
||||
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
|
||||
若没有修改内容,也需要返回原消息。
|
||||
|
||||
@@ -5,6 +5,7 @@ import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||
@@ -56,6 +57,16 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
entities.EnableStage.PRE,
|
||||
entities.EnableStage.POST
|
||||
]
|
||||
|
||||
@property
|
||||
def accept_content(self):
|
||||
"""本过滤器接受的模态
|
||||
|
||||
默认仅接受纯文本
|
||||
"""
|
||||
return [
|
||||
entities.AcceptContent.TEXT
|
||||
]
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化过滤器
|
||||
@@ -63,7 +74,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
@@ -71,6 +82,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
|
||||
Args:
|
||||
message (str): 需要检查的内容
|
||||
image_url (str): 要检查的图片的 URL
|
||||
|
||||
Returns:
|
||||
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...provider import entities as llm_entities
|
||||
@@ -37,9 +39,31 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
# 检查vision是否启用,没启用就删除所有图片
|
||||
if not self.ap.provider_cfg.data['enable-vision']:
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
if me.type == 'image_url':
|
||||
msg.content.remove(me)
|
||||
|
||||
content_list = []
|
||||
|
||||
for me in query.message_chain:
|
||||
if isinstance(me, mirai.Plain):
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_text(me.text)
|
||||
)
|
||||
elif isinstance(me, mirai.Image):
|
||||
if self.ap.provider_cfg.data['enable-vision']:
|
||||
if me.url is not None:
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_image_url(str(me.url))
|
||||
)
|
||||
|
||||
query.user_message = llm_entities.Message( # TODO 适配多模态输入
|
||||
role='user',
|
||||
content=str(query.message_chain).strip()
|
||||
content=content_list
|
||||
)
|
||||
|
||||
query.use_model = conversation.use_model
|
||||
|
||||
@@ -93,15 +93,28 @@ class CommandHandler(handler.MessageHandler):
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif ret.text is not None:
|
||||
elif ret.text is not None or ret.image_url is not None:
|
||||
|
||||
content: list[llm_entities.ContentElement]= []
|
||||
|
||||
if ret.text is not None:
|
||||
content.append(
|
||||
llm_entities.ContentElement.from_text(ret.text)
|
||||
)
|
||||
|
||||
if ret.image_url is not None:
|
||||
content.append(
|
||||
llm_entities.ContentElement.from_image_url(ret.image_url)
|
||||
)
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='command',
|
||||
content=ret.text,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
|
||||
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
||||
@@ -34,17 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
"""
|
||||
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
|
||||
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif query.resp_messages[-1].role == 'plugin':
|
||||
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
|
||||
query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
|
||||
else:
|
||||
query.resp_message_chain.append(query.resp_messages[-1].content)
|
||||
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
|
||||
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
|
||||
# else:
|
||||
# query.resp_message_chain.append(query.resp_messages[-1].content)
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
@@ -59,7 +61,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
reply_text = ''
|
||||
|
||||
if result.content is not None: # 有内容
|
||||
reply_text = result.content
|
||||
reply_text = str(result.get_content_mirai_message_chain())
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
@@ -87,7 +89,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
|
||||
query.resp_message_chain.append(result.get_content_mirai_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
||||
@@ -21,14 +21,34 @@ class ToolCall(pydantic.BaseModel):
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
class Content(pydantic.BaseModel):
|
||||
class ImageURLContentObject(pydantic.BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class ContentElement(pydantic.BaseModel):
|
||||
|
||||
type: str
|
||||
"""内容类型"""
|
||||
|
||||
text: typing.Optional[str] = None
|
||||
|
||||
image_url: typing.Optional[str] = None
|
||||
image_url: typing.Optional[ImageURLContentObject] = None
|
||||
|
||||
def __str__(self):
|
||||
if self.type == 'text':
|
||||
return self.text
|
||||
elif self.type == 'image_url':
|
||||
return f'[图片]({self.image_url})'
|
||||
else:
|
||||
return '未知内容'
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, text: str):
|
||||
return cls(type='text', text=text)
|
||||
|
||||
@classmethod
|
||||
def from_image_url(cls, image_url: str):
|
||||
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
|
||||
|
||||
|
||||
class Message(pydantic.BaseModel):
|
||||
@@ -40,7 +60,7 @@ class Message(pydantic.BaseModel):
|
||||
name: typing.Optional[str] = None
|
||||
"""名称,仅函数调用返回时设置"""
|
||||
|
||||
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
|
||||
content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
|
||||
"""内容"""
|
||||
|
||||
tool_calls: typing.Optional[list[ToolCall]] = None
|
||||
@@ -50,8 +70,38 @@ class Message(pydantic.BaseModel):
|
||||
|
||||
def readable_str(self) -> str:
|
||||
if self.content is not None:
|
||||
return str(self.role) + ": " + str(self.content)
|
||||
return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
|
||||
elif self.tool_calls is not None:
|
||||
return f'调用工具: {self.tool_calls[0].id}'
|
||||
else:
|
||||
return '未知消息'
|
||||
|
||||
def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
|
||||
"""将内容转换为 Mirai MessageChain 对象
|
||||
|
||||
Args:
|
||||
prefix_text (str): 首个文字组件的前缀文本
|
||||
"""
|
||||
|
||||
if self.content is None:
|
||||
return None
|
||||
elif isinstance(self.content, str):
|
||||
return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
|
||||
elif isinstance(self.content, list):
|
||||
mc = []
|
||||
for ce in self.content:
|
||||
if ce.type == 'text':
|
||||
mc.append(mirai.Plain(ce.text))
|
||||
elif ce.type == 'image':
|
||||
mc.append(mirai.Image(url=ce.image_url))
|
||||
|
||||
# 找第一个文字组件
|
||||
if prefix_text:
|
||||
for i, c in enumerate(mc):
|
||||
if isinstance(c, mirai.Plain):
|
||||
mc[i] = mirai.Plain(prefix_text+c.text)
|
||||
break
|
||||
else:
|
||||
mc.insert(0, mirai.Plain(prefix_text))
|
||||
|
||||
return mirai.MessageChain(mc)
|
||||
|
||||
@@ -38,30 +38,42 @@ class AnthropicMessages(api.LLMAPIRequester):
|
||||
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
|
||||
args["model"] = model.name if model.model_name is None else model.model_name
|
||||
|
||||
req_messages = [
|
||||
m.dict(exclude_none=True) for m in messages if m.content.strip() != ""
|
||||
]
|
||||
# 处理消息
|
||||
|
||||
# 删除所有 role=system & content='' 的消息
|
||||
req_messages = [
|
||||
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
|
||||
]
|
||||
# system
|
||||
system_role_message = None
|
||||
|
||||
# 检查是否有 role=system 的消息,若有,改为 role=user,并在后面加一个 role=assistant 的消息
|
||||
system_role_index = []
|
||||
for i, m in enumerate(req_messages):
|
||||
if m["role"] == "system":
|
||||
system_role_index.append(i)
|
||||
m["role"] = "user"
|
||||
for i, m in enumerate(messages):
|
||||
if m.role == "system":
|
||||
system_role_message = m
|
||||
|
||||
if system_role_index:
|
||||
for i in system_role_index[::-1]:
|
||||
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
|
||||
messages.pop(i)
|
||||
break
|
||||
|
||||
# 忽略掉空消息,用户可能发送空消息,而上层未过滤
|
||||
req_messages = [
|
||||
m for m in req_messages if m["content"].strip() != ""
|
||||
]
|
||||
if isinstance(system_role_message, llm_entities.Message) \
|
||||
and isinstance(system_role_message.content, str):
|
||||
args['system'] = system_role_message.content
|
||||
|
||||
# 其他消息
|
||||
# req_messages = [
|
||||
# m.dict(exclude_none=True) for m in messages \
|
||||
# if (isinstance(m.content, str) and m.content.strip() != "") \
|
||||
# or (isinstance(m.content, list) and )
|
||||
# ]
|
||||
# 暂时不支持vision,仅保留纯文字的content
|
||||
req_messages = []
|
||||
|
||||
for m in messages:
|
||||
if isinstance(m.content, str) and m.content.strip() != "":
|
||||
req_messages.append(m.dict(exclude_none=True))
|
||||
elif isinstance(m.content, list):
|
||||
# 删除m.content中的type!=text的元素
|
||||
m.content = [
|
||||
c for c in m.content if c.get("type") == "text"
|
||||
]
|
||||
|
||||
if len(m.content) > 0:
|
||||
req_messages.append(m.dict(exclude_none=True))
|
||||
|
||||
args["messages"] = req_messages
|
||||
|
||||
|
||||
@@ -23,9 +23,18 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
|
||||
|
||||
requester_cfg: dict
|
||||
|
||||
cached_image_oss_url: dict[str, str] = {}
|
||||
"""缓存的OSS服务的图片URL
|
||||
|
||||
key: 前文message中的原图片URL(QQ图片)
|
||||
value: OSS服务的图片URL
|
||||
"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.cached_image_oss_url = {}
|
||||
|
||||
self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
|
||||
|
||||
async def initialize(self):
|
||||
@@ -74,7 +83,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
|
||||
args["tools"] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
# 检查vision
|
||||
if self.ap.oss_mgr.available():
|
||||
for msg in messages:
|
||||
if isinstance(msg["content"], list):
|
||||
for me in msg["content"]:
|
||||
if me["type"] == "image_url":
|
||||
me["image_url"]['url'] = await self.get_oss_url(me["image_url"]['url'])
|
||||
|
||||
args["messages"] = messages
|
||||
|
||||
# 发送请求
|
||||
@@ -112,3 +130,17 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||
except openai.APIError as e:
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
|
||||
async def get_oss_url(
|
||||
self,
|
||||
original_url: str,
|
||||
) -> str:
|
||||
|
||||
if original_url in self.cached_image_oss_url:
|
||||
return self.cached_image_oss_url[original_url]
|
||||
|
||||
oss_url = await self.ap.oss_mgr.upload_url_image(original_url)
|
||||
|
||||
self.cached_image_oss_url[original_url] = oss_url
|
||||
|
||||
return oss_url
|
||||
|
||||
@@ -3,7 +3,10 @@ from __future__ import annotations
|
||||
from ....core import app
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import api
|
||||
from .. import api, entities, errors
|
||||
from ....core import entities as core_entities, app
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
@api.requester_class("deepseek-chat-completions")
|
||||
@@ -12,4 +15,39 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
|
||||
self.ap = ap
|
||||
self.ap = ap
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
if use_model.tool_call_supported:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
|
||||
# deepseek 不支持多模态,把content都转换成纯文字
|
||||
for m in messages:
|
||||
if isinstance(m["content"], list):
|
||||
m["content"] = " ".join([c["text"] for c in m["content"]])
|
||||
|
||||
args["messages"] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
@@ -3,7 +3,10 @@ from __future__ import annotations
|
||||
from ....core import app
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import api
|
||||
from .. import api, entities, errors
|
||||
from ....core import entities as core_entities, app
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
@api.requester_class("moonshot-chat-completions")
|
||||
@@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
def __init__(self, ap: app.Application):
|
||||
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
|
||||
self.ap = ap
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
if use_model.tool_call_supported:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
|
||||
# deepseek 不支持多模态,把content都转换成纯文字
|
||||
for m in messages:
|
||||
if isinstance(m["content"], list):
|
||||
m["content"] = " ".join([c["text"] for c in m["content"]])
|
||||
|
||||
# 删除空的
|
||||
messages = [m for m in messages if m["content"].strip() != ""]
|
||||
|
||||
args["messages"] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
@@ -37,6 +37,10 @@ class ModelManager:
|
||||
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
# 检查是否启用了vision但是没有配置oss
|
||||
if self.ap.provider_cfg.data['enable-vision'] and not self.ap.oss_mgr.available():
|
||||
self.ap.logger.warn("启用了视觉但是没有配置可用的oss服务,基于 URL 传递图片的视觉 API 将无法正常使用")
|
||||
|
||||
# 初始化token_mgr, requester
|
||||
for k, v in self.ap.provider_cfg.data['keys'].items():
|
||||
|
||||
@@ -13,4 +13,6 @@ aiohttp
|
||||
pydantic
|
||||
websockets
|
||||
urllib3
|
||||
psutil
|
||||
psutil
|
||||
|
||||
oss2
|
||||
@@ -22,6 +22,10 @@
|
||||
"name": "gpt-4-32k",
|
||||
"tool_call_supported": true
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"tool_call_supported": true
|
||||
},
|
||||
{
|
||||
"model_name": "SparkDesk",
|
||||
"name": "OneAPI/SparkDesk"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"enable-chat": true,
|
||||
"enable-vision": false,
|
||||
"keys": {
|
||||
"openai": [
|
||||
"sk-1234567890"
|
||||
|
||||
@@ -1,5 +1,17 @@
|
||||
{
|
||||
"admin-sessions": [],
|
||||
"oss": [
|
||||
{
|
||||
"type": "aliyun",
|
||||
"endpoint": "https://oss-cn-hangzhou.aliyuncs.com",
|
||||
"public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com",
|
||||
"access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5",
|
||||
"access-key-secret": "xxxxxx",
|
||||
"bucket": "qchatgpt",
|
||||
"prefix": "qchatgpt",
|
||||
"enable": false
|
||||
}
|
||||
],
|
||||
"network-proxies": {
|
||||
"http": null,
|
||||
"https": null
|
||||
|
||||
Reference in New Issue
Block a user