diff --git a/pkg/command/entities.py b/pkg/command/entities.py index 27cb5962..86975510 100644 --- a/pkg/command/entities.py +++ b/pkg/command/entities.py @@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel): """命令返回值 """ - text: typing.Optional[str] + text: typing.Optional[str] = None """文本 """ - image: typing.Optional[mirai.Image] + image: typing.Optional[mirai.Image] = None + """弃用""" + + image_url: typing.Optional[str] = None + """图片链接 + """ error: typing.Optional[errors.CommandError]= None """错误 diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py index 754bfa58..362bc78a 100644 --- a/pkg/config/impls/json.py +++ b/pkg/config/impls/json.py @@ -27,7 +27,7 @@ class JSONConfigFile(file_model.ConfigFile): else: raise ValueError("template_file_name or template_data must be provided") - async def load(self) -> dict: + async def load(self, completion: bool=True) -> dict: if not self.exists(): await self.create() @@ -39,9 +39,11 @@ class JSONConfigFile(file_model.ConfigFile): with open(self.config_file_name, "r", encoding="utf-8") as f: cfg = json.load(f) - for key in self.template_data: - if key not in cfg: - cfg[key] = self.template_data[key] + if completion: + + for key in self.template_data: + if key not in cfg: + cfg[key] = self.template_data[key] return cfg diff --git a/pkg/config/impls/pymodule.py b/pkg/config/impls/pymodule.py index ceeebad1..67e5867d 100644 --- a/pkg/config/impls/pymodule.py +++ b/pkg/config/impls/pymodule.py @@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile): async def create(self): shutil.copyfile(self.template_file_name, self.config_file_name) - async def load(self) -> dict: + async def load(self, completion: bool=True) -> dict: module_name = os.path.splitext(os.path.basename(self.config_file_name))[0] module = importlib.import_module(module_name) @@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile): cfg[key] = getattr(module, key) # 从模板模块文件中进行补全 - module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] - module = importlib.import_module(module_name) + if completion: + module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] + module = importlib.import_module(module_name) - for key in dir(module): - if key.startswith('__'): - continue + for key in dir(module): + if key.startswith('__'): + continue - if not isinstance(getattr(module, key), allowed_types): - continue + if not isinstance(getattr(module, key), allowed_types): + continue - if key not in cfg: - cfg[key] = getattr(module, key) + if key not in cfg: + cfg[key] = getattr(module, key) return cfg diff --git a/pkg/config/manager.py b/pkg/config/manager.py index f9e93c81..7983407c 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -20,8 +20,8 @@ class ConfigManager: self.file = cfg_file self.data = {} - async def load_config(self): - self.data = await self.file.load() + async def load_config(self, completion: bool=True): + self.data = await self.file.load(completion=completion) async def dump_config(self): await self.file.save(self.data) @@ -30,7 +30,7 @@ class ConfigManager: self.file.save_sync(self.data) -async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager: +async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager: """加载Python模块配置文件""" cfg_inst = pymodule.PythonModuleConfigFile( config_name, @@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con ) cfg_mgr = ConfigManager(cfg_inst) - await cfg_mgr.load_config() + await cfg_mgr.load_config(completion=completion) return cfg_mgr -async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager: +async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager: """加载JSON配置文件""" cfg_inst = json_file.JSONConfigFile( config_name, @@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_d ) cfg_mgr = ConfigManager(cfg_inst) - await cfg_mgr.load_config() + await cfg_mgr.load_config(completion=completion) return cfg_mgr \ No newline at end of file diff --git a/pkg/config/migrations/m006_vision_and_oss_config.py b/pkg/config/migrations/m006_vision_and_oss_config.py new file mode 100644 index 00000000..2d05b975 --- /dev/null +++ b/pkg/config/migrations/m006_vision_and_oss_config.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("vision-and-oss-config", 6) +class VisionAndOSSConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return "enable-vision" not in self.ap.provider_cfg.data \ + or "oss" not in self.ap.system_cfg.data + + async def run(self): + """执行迁移""" + if "enable-vision" not in self.ap.provider_cfg.data: + self.ap.provider_cfg.data["enable-vision"] = False + + if "oss" not in self.ap.system_cfg.data: + self.ap.system_cfg.data["oss"] = [ + { + "type": "aliyun", + "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", + "public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com", + "access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5", + "access-key-secret": "xxxxxx", + "bucket": "qchatgpt", + "prefix": "qchatgpt", + "enable": False, + } + ] + + await self.ap.provider_cfg.dump_config() + await self.ap.system_cfg.dump_config() diff --git a/pkg/config/model.py b/pkg/config/model.py index d209093c..153123e3 100644 --- a/pkg/config/model.py +++ b/pkg/config/model.py @@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def load(self) -> dict: + async def load(self, completion: bool=True) -> dict: pass @abc.abstractmethod diff --git a/pkg/core/app.py b/pkg/core/app.py index 1ed53042..1705e299 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -15,6 +15,7 @@ from ..command import cmdmgr from ..plugin import manager as plugin_mgr from ..pipeline import pool from ..pipeline import controller, stagemgr +from ..oss import oss from ..utils import version as version_mgr, proxy as proxy_mgr @@ -71,6 +72,8 @@ class Application: proxy_mgr: proxy_mgr.ProxyManager = None + oss_mgr: oss.OSSServiceManager = None + logger: logging.Logger = None def __init__(self): diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 4adf1323..ab407048 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -14,6 +14,7 @@ required_deps = { "yaml": "pyyaml", "aiohttp": "aiohttp", "psutil": "psutil", + "oss2": "oss2", } diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index d5365826..ff8c1e74 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -14,6 +14,7 @@ from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr from ...platform import manager as im_mgr +from ...oss import oss as oss_mgr @stage.stage_class("BuildAppStage") @@ -68,6 +69,10 @@ class BuildAppStage(stage.BootingStage): await cmd_mgr_inst.initialize() ap.cmd_mgr = cmd_mgr_inst + oss_mgr_inst = oss_mgr.OSSServiceManager(ap) + await oss_mgr_inst.initialize() + ap.oss_mgr = oss_mgr_inst + llm_model_mgr_inst = llm_model_mgr.ModelManager(ap) await llm_model_mgr_inst.initialize() ap.model_mgr = llm_model_mgr_inst @@ -83,7 +88,6 @@ class BuildAppStage(stage.BootingStage): llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) await llm_tool_mgr_inst.initialize() ap.tool_mgr = llm_tool_mgr_inst - im_mgr_inst = im_mgr.PlatformManager(ap=ap) await im_mgr_inst.initialize() ap.platform_mgr = im_mgr_inst @@ -92,5 +96,6 @@ class BuildAppStage(stage.BootingStage): await stage_mgr.initialize() ap.stage_mgr = stage_mgr + ctrl = controller.Controller(ap) ap.ctrl = ctrl diff --git a/pkg/core/stages/load_config.py b/pkg/core/stages/load_config.py index 9e61c1cb..cb6e1ed0 100644 --- a/pkg/core/stages/load_config.py +++ b/pkg/core/stages/load_config.py @@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage): async def run(self, ap: app.Application): """启动 """ - ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json") - ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json") - ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json") - ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json") - ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json") + ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False) + ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False) + ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False) + ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False) + ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False) ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json") await ap.plugin_setting_meta.dump_config() diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index cef3b42d..44102fea 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -5,7 +5,7 @@ import importlib from .. import stage, app from ...config import migration from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion -from ...config.migrations import m005_deepseek_cfg_completion +from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_and_oss_config @stage.stage_class("MigrationStage") diff --git a/pkg/oss/__init__.py b/pkg/oss/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/oss/oss.py b/pkg/oss/oss.py new file mode 100644 index 00000000..5474ed39 --- /dev/null +++ b/pkg/oss/oss.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import aiohttp +import typing +from urllib.parse import urlparse, parse_qs +import ssl + +from . import service as osssv +from ..core import app +from .services import aliyun + + +class OSSServiceManager: + + ap: app.Application + + service: osssv.OSSService = None + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + """初始化 + """ + + mapping = {} + + for svcls in osssv.preregistered_services: + mapping[svcls.name] = svcls + + for sv in self.ap.system_cfg.data['oss']: + if sv['enable']: + + if sv['type'] not in mapping: + raise Exception(f"未知的OSS服务类型: {sv['type']}") + + self.service = mapping[sv['type']](self.ap, sv) + await self.service.initialize() + break + + def available(self) -> bool: + """是否可用 + + Returns: + bool: 是否可用 + """ + return self.service is not None + + async def fetch_image(self, image_url: str) -> bytes: + parsed = urlparse(image_url) + query = parse_qs(parsed.query) + + # Flatten the query dictionary + query = {k: v[0] for k, v in query.items()} + + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + async with aiohttp.ClientSession(trust_env=False) as session: + async with session.get( + f"http://{parsed.netloc}{parsed.path}", + params=query, + ssl=ssl_context + ) as resp: + resp.raise_for_status() # 检查HTTP错误 + file_bytes = await resp.read() + return file_bytes + + async def upload_url_image( + self, + image_url: str, + ) -> str: + """上传URL图片 + + Args: + image_url (str): 图片URL + + Returns: + str: 文件URL + """ + + file_bytes = await self.fetch_image(image_url) + + return await self.service.upload(file_bytes=file_bytes, ext=".jpg") \ No newline at end of file diff --git a/pkg/oss/service.py b/pkg/oss/service.py new file mode 100644 index 00000000..a8228447 --- /dev/null +++ b/pkg/oss/service.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import typing +import abc + +from ..core import app + + +preregistered_services: list[typing.Type[OSSService]] = [] + +def service_class( + name: str +) -> typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]: + """OSS服务类装饰器 + + Args: + name (str): 服务名称 + + Returns: + typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]: 装饰器 + """ + def decorator(cls: typing.Type[OSSService]) -> typing.Type[OSSService]: + assert issubclass(cls, OSSService) + + cls.name = name + + preregistered_services.append(cls) + + return cls + + return decorator + + +class OSSService(metaclass=abc.ABCMeta): + """OSS抽象类""" + + name: str + + ap: app.Application + + cfg: dict + + def __init__(self, ap: app.Application, cfg: dict) -> None: + self.ap = ap + self.cfg = cfg + + async def initialize(self): + pass + + @abc.abstractmethod + async def upload( + self, + local_file: str=None, + file_bytes: bytes=None, + ext: str=None, + ) -> str: + """上传文件 + + Args: + local_file (str, optional): 本地文件路径. Defaults to None. + file_bytes (bytes, optional): 文件字节. Defaults to None. + ext (str, optional): 文件扩展名. Defaults to None. + + Returns: + str: 文件URL + """ + pass diff --git a/pkg/oss/services/__init__.py b/pkg/oss/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/oss/services/aliyun.py b/pkg/oss/services/aliyun.py new file mode 100644 index 00000000..d30ac895 --- /dev/null +++ b/pkg/oss/services/aliyun.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import uuid + +import oss2 + +from .. import service as osssv + + +@osssv.service_class('aliyun') +class AliyunOSSService(osssv.OSSService): + """阿里云OSS服务""" + + auth: oss2.Auth + + bucket: oss2.Bucket + + async def initialize(self): + self.auth = oss2.Auth( + self.cfg['access-key-id'], + self.cfg['access-key-secret'] + ) + + self.bucket = oss2.Bucket( + self.auth, + self.cfg['endpoint'], + self.cfg['bucket'] + ) + + async def upload( + self, + local_file: str=None, + file_bytes: bytes=None, + ext: str=None, + ) -> str: + if local_file is not None: + with open(local_file, 'rb') as f: + file_bytes = f.read() + + if file_bytes is None: + raise Exception("缺少文件内容") + + name = str(uuid.uuid1()) + + key = f"{self.cfg['prefix']}/{name}{ext}" + self.bucket.put_object(key, file_bytes) + + return f"{self.cfg['public-read-base-url']}/{key}" diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 2c6a5ab9..a669e310 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -9,6 +9,7 @@ from ...core import entities as core_entities from ...config import manager as cfg_mgr from . import filter as filter_model, entities as filter_entities from .filters import cntignore, banwords, baiduexamine +from ...provider import entities as llm_entities @stage.stage_class('PostContentFilterStage') @@ -141,6 +142,21 @@ class ContentFilterStage(stage.PipelineStage): """处理 """ if stage_inst_name == 'PreContentFilterStage': + + contain_non_text = False + + for me in query.message_chain: + if not isinstance(me, mirai.Plain): + contain_non_text = True + break + + if contain_non_text: + self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。") + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + return await self._pre_process( str(query.message_chain).strip(), query diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py index 8ff581fb..83744c1d 100644 --- a/pkg/pipeline/cntfilter/entities.py +++ b/pkg/pipeline/cntfilter/entities.py @@ -4,6 +4,8 @@ import enum import pydantic +from ...provider import entities as llm_entities + class ResultLevel(enum.Enum): """结果等级""" @@ -29,6 +31,13 @@ class EnableStage(enum.Enum): """后处理""" +class AcceptContent(enum.Enum): + """过滤器接受的内容模态""" + + TEXT = enum.auto() + + IMAGE_URL = enum.auto() + class FilterResult(pydantic.BaseModel): level: ResultLevel """结果等级 @@ -38,7 +47,7 @@ class FilterResult(pydantic.BaseModel): """ replacement: str - """替换后的消息 + """替换后的文本消息 内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。 若没有修改内容,也需要返回原消息。 diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 8b34e0c5..5fd55e42 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -5,6 +5,7 @@ import typing from ...core import app from . import entities +from ...provider import entities as llm_entities preregistered_filters: list[typing.Type[ContentFilter]] = [] @@ -56,6 +57,16 @@ class ContentFilter(metaclass=abc.ABCMeta): entities.EnableStage.PRE, entities.EnableStage.POST ] + + @property + def accept_content(self): + """本过滤器接受的模态 + + 默认仅接受纯文本 + """ + return [ + entities.AcceptContent.TEXT + ] async def initialize(self): """初始化过滤器 @@ -63,7 +74,7 @@ class ContentFilter(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process(self, message: str) -> entities.FilterResult: + async def process(self, message: str=None, image_url=None) -> entities.FilterResult: """处理消息 分为前后阶段,具体取决于 enable_stages 的值。 @@ -71,6 +82,7 @@ class ContentFilter(metaclass=abc.ABCMeta): Args: message (str): 需要检查的内容 + image_url (str): 要检查的图片的 URL Returns: entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档 diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 164f78c8..0470a607 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -1,5 +1,7 @@ from __future__ import annotations +import mirai + from .. import stage, entities, stagemgr from ...core import entities as core_entities from ...provider import entities as llm_entities @@ -37,9 +39,31 @@ class PreProcessor(stage.PipelineStage): query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() + # 检查vision是否启用,没启用就删除所有图片 + if not self.ap.provider_cfg.data['enable-vision']: + for msg in query.messages: + if isinstance(msg.content, list): + for me in msg.content: + if me.type == 'image_url': + msg.content.remove(me) + + content_list = [] + + for me in query.message_chain: + if isinstance(me, mirai.Plain): + content_list.append( + llm_entities.ContentElement.from_text(me.text) + ) + elif isinstance(me, mirai.Image): + if self.ap.provider_cfg.data['enable-vision']: + if me.url is not None: + content_list.append( + llm_entities.ContentElement.from_image_url(str(me.url)) + ) + query.user_message = llm_entities.Message( # TODO 适配多模态输入 role='user', - content=str(query.message_chain).strip() + content=content_list ) query.use_model = conversation.use_model diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index 75d9222c..02ff269a 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -93,15 +93,28 @@ class CommandHandler(handler.MessageHandler): result_type=entities.ResultType.CONTINUE, new_query=query ) - elif ret.text is not None: + elif ret.text is not None or ret.image_url is not None: + + content: list[llm_entities.ContentElement]= [] + + if ret.text is not None: + content.append( + llm_entities.ContentElement.from_text(ret.text) + ) + + if ret.image_url is not None: + content.append( + llm_entities.ContentElement.from_image_url(ret.image_url) + ) + query.resp_messages.append( llm_entities.Message( role='command', - content=ret.text, + content=content, ) ) - self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}') + self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}') yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index 345addb8..acf0549d 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -34,17 +34,19 @@ class ResponseWrapper(stage.PipelineStage): """ if query.resp_messages[-1].role == 'command': - query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content)) + # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content)) + query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] ')) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query ) elif query.resp_messages[-1].role == 'plugin': - if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): - query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content)) - else: - query.resp_message_chain.append(query.resp_messages[-1].content) + # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): + # query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content)) + # else: + # query.resp_message_chain.append(query.resp_messages[-1].content) + query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain()) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -59,7 +61,7 @@ class ResponseWrapper(stage.PipelineStage): reply_text = '' if result.content is not None: # 有内容 - reply_text = result.content + reply_text = str(result.get_content_mirai_message_chain()) # ============= 触发插件事件 =============== event_ctx = await self.ap.plugin_mgr.emit_event( @@ -87,7 +89,7 @@ class ResponseWrapper(stage.PipelineStage): else: - query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) + query.resp_message_chain.append(result.get_content_mirai_message_chain()) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 3281a93b..b892b89b 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -21,14 +21,34 @@ class ToolCall(pydantic.BaseModel): function: FunctionCall -class Content(pydantic.BaseModel): +class ImageURLContentObject(pydantic.BaseModel): + url: str + + +class ContentElement(pydantic.BaseModel): type: str """内容类型""" text: typing.Optional[str] = None - image_url: typing.Optional[str] = None + image_url: typing.Optional[ImageURLContentObject] = None + + def __str__(self): + if self.type == 'text': + return self.text + elif self.type == 'image_url': + return f'[图片]({self.image_url})' + else: + return '未知内容' + + @classmethod + def from_text(cls, text: str): + return cls(type='text', text=text) + + @classmethod + def from_image_url(cls, image_url: str): + return cls(type='image_url', image_url=ImageURLContentObject(url=image_url)) class Message(pydantic.BaseModel): @@ -40,7 +60,7 @@ class Message(pydantic.BaseModel): name: typing.Optional[str] = None """名称,仅函数调用返回时设置""" - content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None + content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None """内容""" tool_calls: typing.Optional[list[ToolCall]] = None @@ -50,8 +70,38 @@ class Message(pydantic.BaseModel): def readable_str(self) -> str: if self.content is not None: - return str(self.role) + ": " + str(self.content) + return str(self.role) + ": " + str(self.get_content_mirai_message_chain()) elif self.tool_calls is not None: return f'调用工具: {self.tool_calls[0].id}' else: return '未知消息' + + def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None: + """将内容转换为 Mirai MessageChain 对象 + + Args: + prefix_text (str): 首个文字组件的前缀文本 + """ + + if self.content is None: + return None + elif isinstance(self.content, str): + return mirai.MessageChain([mirai.Plain(prefix_text+self.content)]) + elif isinstance(self.content, list): + mc = [] + for ce in self.content: + if ce.type == 'text': + mc.append(mirai.Plain(ce.text)) + elif ce.type == 'image': + mc.append(mirai.Image(url=ce.image_url)) + + # 找第一个文字组件 + if prefix_text: + for i, c in enumerate(mc): + if isinstance(c, mirai.Plain): + mc[i] = mirai.Plain(prefix_text+c.text) + break + else: + mc.insert(0, mirai.Plain(prefix_text)) + + return mirai.MessageChain(mc) diff --git a/pkg/provider/modelmgr/apis/anthropicmsgs.py b/pkg/provider/modelmgr/apis/anthropicmsgs.py index 923e1ceb..ee2c51a5 100644 --- a/pkg/provider/modelmgr/apis/anthropicmsgs.py +++ b/pkg/provider/modelmgr/apis/anthropicmsgs.py @@ -38,30 +38,42 @@ class AnthropicMessages(api.LLMAPIRequester): args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy() args["model"] = model.name if model.model_name is None else model.model_name - req_messages = [ - m.dict(exclude_none=True) for m in messages if m.content.strip() != "" - ] + # 处理消息 - # 删除所有 role=system & content='' 的消息 - req_messages = [ - m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "") - ] + # system + system_role_message = None - # 检查是否有 role=system 的消息,若有,改为 role=user,并在后面加一个 role=assistant 的消息 - system_role_index = [] - for i, m in enumerate(req_messages): - if m["role"] == "system": - system_role_index.append(i) - m["role"] = "user" + for i, m in enumerate(messages): + if m.role == "system": + system_role_message = m - if system_role_index: - for i in system_role_index[::-1]: - req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."}) + messages.pop(i) + break - # 忽略掉空消息,用户可能发送空消息,而上层未过滤 - req_messages = [ - m for m in req_messages if m["content"].strip() != "" - ] + if isinstance(system_role_message, llm_entities.Message) \ + and isinstance(system_role_message.content, str): + args['system'] = system_role_message.content + + # 其他消息 + # req_messages = [ + # m.dict(exclude_none=True) for m in messages \ + # if (isinstance(m.content, str) and m.content.strip() != "") \ + # or (isinstance(m.content, list) and ) + # ] + # 暂时不支持vision,仅保留纯文字的content + req_messages = [] + + for m in messages: + if isinstance(m.content, str) and m.content.strip() != "": + req_messages.append(m.dict(exclude_none=True)) + elif isinstance(m.content, list): + # 删除m.content中的type!=text的元素 + m.content = [ + c for c in m.content if c.get("type") == "text" + ] + + if len(m.content) > 0: + req_messages.append(m.dict(exclude_none=True)) args["messages"] = req_messages diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py index 7984dd83..fb6e0575 100644 --- a/pkg/provider/modelmgr/apis/chatcmpl.py +++ b/pkg/provider/modelmgr/apis/chatcmpl.py @@ -23,9 +23,18 @@ class OpenAIChatCompletions(api.LLMAPIRequester): requester_cfg: dict + cached_image_oss_url: dict[str, str] = {} + """缓存的OSS服务的图片URL + + key: 前文message中的原图片URL(QQ图片) + value: OSS服务的图片URL + """ + def __init__(self, ap: app.Application): self.ap = ap + self.cached_image_oss_url = {} + self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions'] async def initialize(self): @@ -74,7 +83,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester): args["tools"] = tools # 设置此次请求中的messages - messages = req_messages + messages = req_messages.copy() + + # 检查vision + if self.ap.oss_mgr.available(): + for msg in messages: + if isinstance(msg["content"], list): + for me in msg["content"]: + if me["type"] == "image_url": + me["image_url"]['url'] = await self.get_oss_url(me["image_url"]['url']) + args["messages"] = messages # 发送请求 @@ -112,3 +130,17 @@ class OpenAIChatCompletions(api.LLMAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def get_oss_url( + self, + original_url: str, + ) -> str: + + if original_url in self.cached_image_oss_url: + return self.cached_image_oss_url[original_url] + + oss_url = await self.ap.oss_mgr.upload_url_image(original_url) + + self.cached_image_oss_url[original_url] = oss_url + + return oss_url diff --git a/pkg/provider/modelmgr/apis/deepseekchatcmpl.py b/pkg/provider/modelmgr/apis/deepseekchatcmpl.py index dd8ddc6d..9cb667b7 100644 --- a/pkg/provider/modelmgr/apis/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/apis/deepseekchatcmpl.py @@ -3,7 +3,10 @@ from __future__ import annotations from ....core import app from . import chatcmpl -from .. import api +from .. import api, entities, errors +from ....core import entities as core_entities, app +from ... import entities as llm_entities +from ...tools import entities as tools_entities @api.requester_class("deepseek-chat-completions") @@ -12,4 +15,39 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): def __init__(self, ap: app.Application): self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions'] - self.ap = ap \ No newline at end of file + self.ap = ap + + async def _closure( + self, + req_messages: list[dict], + use_model: entities.LLMModelInfo, + use_funcs: list[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + self.client.api_key = use_model.token_mgr.get_token() + + args = self.requester_cfg['args'].copy() + args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + + if use_model.tool_call_supported: + tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) + + if tools: + args["tools"] = tools + + # 设置此次请求中的messages + messages = req_messages + + # deepseek 不支持多模态,把content都转换成纯文字 + for m in messages: + if isinstance(m["content"], list): + m["content"] = " ".join([c["text"] for c in m["content"]]) + + args["messages"] = messages + + # 发送请求 + resp = await self._req(args) + + # 处理请求结果 + message = await self._make_msg(resp) + + return message \ No newline at end of file diff --git a/pkg/provider/modelmgr/apis/moonshotchatcmpl.py b/pkg/provider/modelmgr/apis/moonshotchatcmpl.py index cb9fd934..f50ca628 100644 --- a/pkg/provider/modelmgr/apis/moonshotchatcmpl.py +++ b/pkg/provider/modelmgr/apis/moonshotchatcmpl.py @@ -3,7 +3,10 @@ from __future__ import annotations from ....core import app from . import chatcmpl -from .. import api +from .. import api, entities, errors +from ....core import entities as core_entities, app +from ... import entities as llm_entities +from ...tools import entities as tools_entities @api.requester_class("moonshot-chat-completions") @@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): def __init__(self, ap: app.Application): self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions'] self.ap = ap + + async def _closure( + self, + req_messages: list[dict], + use_model: entities.LLMModelInfo, + use_funcs: list[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + self.client.api_key = use_model.token_mgr.get_token() + + args = self.requester_cfg['args'].copy() + args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + + if use_model.tool_call_supported: + tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) + + if tools: + args["tools"] = tools + + # 设置此次请求中的messages + messages = req_messages + + # deepseek 不支持多模态,把content都转换成纯文字 + for m in messages: + if isinstance(m["content"], list): + m["content"] = " ".join([c["text"] for c in m["content"]]) + + # 删除空的 + messages = [m for m in messages if m["content"].strip() != ""] + + args["messages"] = messages + + # 发送请求 + resp = await self._req(args) + + # 处理请求结果 + message = await self._make_msg(resp) + + return message \ No newline at end of file diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 3fffd784..93cf54dd 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -37,6 +37,10 @@ class ModelManager: raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置") async def initialize(self): + + # 检查是否启用了vision但是没有配置oss + if self.ap.provider_cfg.data['enable-vision'] and not self.ap.oss_mgr.available(): + self.ap.logger.warn("启用了视觉但是没有配置可用的oss服务,基于 URL 传递图片的视觉 API 将无法正常使用") # 初始化token_mgr, requester for k, v in self.ap.provider_cfg.data['keys'].items(): diff --git a/requirements.txt b/requirements.txt index f04bdc9a..63996e0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,6 @@ aiohttp pydantic websockets urllib3 -psutil \ No newline at end of file +psutil + +oss2 \ No newline at end of file diff --git a/templates/metadata/llm-models.json b/templates/metadata/llm-models.json index 9787223e..13cf93c2 100644 --- a/templates/metadata/llm-models.json +++ b/templates/metadata/llm-models.json @@ -22,6 +22,10 @@ "name": "gpt-4-32k", "tool_call_supported": true }, + { + "name": "gpt-4o", + "tool_call_supported": true + }, { "model_name": "SparkDesk", "name": "OneAPI/SparkDesk" diff --git a/templates/provider.json b/templates/provider.json index e537156b..dadec8ea 100644 --- a/templates/provider.json +++ b/templates/provider.json @@ -1,5 +1,6 @@ { "enable-chat": true, + "enable-vision": false, "keys": { "openai": [ "sk-1234567890" diff --git a/templates/system.json b/templates/system.json index 72d29b98..906640c7 100644 --- a/templates/system.json +++ b/templates/system.json @@ -1,5 +1,17 @@ { "admin-sessions": [], + "oss": [ + { + "type": "aliyun", + "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", + "public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com", + "access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5", + "access-key-secret": "xxxxxx", + "bucket": "qchatgpt", + "prefix": "qchatgpt", + "enable": false + } + ], "network-proxies": { "http": null, "https": null