feat: 模型视觉多模态支持

This commit is contained in:
RockChinQ
2024-05-15 21:40:18 +08:00
parent 8807f02f36
commit d5b5d667a5
32 changed files with 596 additions and 72 deletions

View File

@@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel):
"""命令返回值
"""
text: typing.Optional[str]
text: typing.Optional[str] = None
"""文本
"""
image: typing.Optional[mirai.Image]
image: typing.Optional[mirai.Image] = None
"""弃用"""
image_url: typing.Optional[str] = None
"""图片链接
"""
error: typing.Optional[errors.CommandError]= None
"""错误

View File

@@ -27,7 +27,7 @@ class JSONConfigFile(file_model.ConfigFile):
else:
raise ValueError("template_file_name or template_data must be provided")
async def load(self) -> dict:
async def load(self, completion: bool=True) -> dict:
if not self.exists():
await self.create()
@@ -39,9 +39,11 @@ class JSONConfigFile(file_model.ConfigFile):
with open(self.config_file_name, "r", encoding="utf-8") as f:
cfg = json.load(f)
for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
if completion:
for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
return cfg

View File

@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self) -> dict:
async def load(self, completion: bool=True) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name)
@@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile):
cfg[key] = getattr(module, key)
# 从模板模块文件中进行补全
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
module = importlib.import_module(module_name)
if completion:
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
module = importlib.import_module(module_name)
for key in dir(module):
if key.startswith('__'):
continue
for key in dir(module):
if key.startswith('__'):
continue
if not isinstance(getattr(module, key), allowed_types):
continue
if not isinstance(getattr(module, key), allowed_types):
continue
if key not in cfg:
cfg[key] = getattr(module, key)
if key not in cfg:
cfg[key] = getattr(module, key)
return cfg

View File

@@ -20,8 +20,8 @@ class ConfigManager:
self.file = cfg_file
self.data = {}
async def load_config(self):
self.data = await self.file.load()
async def load_config(self, completion: bool=True):
self.data = await self.file.load(completion=completion)
async def dump_config(self):
await self.file.save(self.data)
@@ -30,7 +30,7 @@ class ConfigManager:
self.file.save_sync(self.data)
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
"""加载Python模块配置文件"""
cfg_inst = pymodule.PythonModuleConfigFile(
config_name,
@@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config()
await cfg_mgr.load_config(completion=completion)
return cfg_mgr
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
"""加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile(
config_name,
@@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config()
await cfg_mgr.load_config(completion=completion)
return cfg_mgr

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("vision-and-oss-config", 6)
class VisionAndOSSConfigMigration(migration.Migration):
    """Migration that adds the vision switch and the OSS service list.

    Fills in ``enable-vision`` in the provider config and ``oss`` in the
    system config when either key is absent, then writes both configs back.
    """

    async def need_migrate(self) -> bool:
        """Report whether at least one of the two keys is still missing."""
        if "enable-vision" not in self.ap.provider_cfg.data:
            return True
        return "oss" not in self.ap.system_cfg.data

    async def run(self):
        """Insert the missing defaults and persist both config files."""
        # Vision stays off until the user enables it explicitly.
        self.ap.provider_cfg.data.setdefault("enable-vision", False)

        # Placeholder Aliyun entry, disabled by default.
        default_oss_services = [
            {
                "type": "aliyun",
                "endpoint": "https://oss-cn-hangzhou.aliyuncs.com",
                "public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com",
                "access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5",
                "access-key-secret": "xxxxxx",
                "bucket": "qchatgpt",
                "prefix": "qchatgpt",
                "enable": False,
            }
        ]
        self.ap.system_cfg.data.setdefault("oss", default_oss_services)

        # Both files are dumped regardless of which key was added.
        await self.ap.provider_cfg.dump_config()
        await self.ap.system_cfg.dump_config()

View File

@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def load(self) -> dict:
async def load(self, completion: bool=True) -> dict:
pass
@abc.abstractmethod

View File

@@ -15,6 +15,7 @@ from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..pipeline import pool
from ..pipeline import controller, stagemgr
from ..oss import oss
from ..utils import version as version_mgr, proxy as proxy_mgr
@@ -71,6 +72,8 @@ class Application:
proxy_mgr: proxy_mgr.ProxyManager = None
oss_mgr: oss.OSSServiceManager = None
logger: logging.Logger = None
def __init__(self):

View File

@@ -14,6 +14,7 @@ required_deps = {
"yaml": "pyyaml",
"aiohttp": "aiohttp",
"psutil": "psutil",
"oss2": "oss2",
}

View File

@@ -14,6 +14,7 @@ from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr
from ...oss import oss as oss_mgr
@stage.stage_class("BuildAppStage")
@@ -68,6 +69,10 @@ class BuildAppStage(stage.BootingStage):
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
oss_mgr_inst = oss_mgr.OSSServiceManager(ap)
await oss_mgr_inst.initialize()
ap.oss_mgr = oss_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst
@@ -83,7 +88,6 @@ class BuildAppStage(stage.BootingStage):
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst
@@ -92,5 +96,6 @@ class BuildAppStage(stage.BootingStage):
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage):
async def run(self, ap: app.Application):
"""启动
"""
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
await ap.plugin_setting_meta.dump_config()

View File

@@ -5,7 +5,7 @@ import importlib
from .. import stage, app
from ...config import migration
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ...config.migrations import m005_deepseek_cfg_completion
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_and_oss_config
@stage.stage_class("MigrationStage")

0
pkg/oss/__init__.py Normal file
View File

85
pkg/oss/oss.py Normal file
View File

@@ -0,0 +1,85 @@
from __future__ import annotations
import aiohttp
import typing
from urllib.parse import urlparse, parse_qs
import ssl
from . import service as osssv
from ..core import app
from .services import aliyun
class OSSServiceManager:
    """Picks one configured OSS backend and uploads images through it."""

    ap: app.Application

    # The active backend; stays None when no config entry is enabled.
    service: osssv.OSSService = None

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        """Instantiate the first enabled entry of ``oss`` in the system config.

        Raises:
            Exception: an enabled entry names a type that no registered
                service class provides.
        """
        mapping = {svcls.name: svcls for svcls in osssv.preregistered_services}

        for sv in self.ap.system_cfg.data['oss']:
            if not sv['enable']:
                continue

            if sv['type'] not in mapping:
                raise Exception(f"未知的OSS服务类型: {sv['type']}")

            self.service = mapping[sv['type']](self.ap, sv)
            await self.service.initialize()
            # Only the first enabled entry is used.
            break

    def available(self) -> bool:
        """Tell whether a backend was selected during initialization.

        Returns:
            bool: True when ``self.service`` is set.
        """
        return self.service is not None

    async def fetch_image(self, image_url: str) -> bytes:
        """Download the image behind ``image_url`` and return its raw bytes.

        The request is always sent to ``http://<netloc><path>`` regardless of
        the scheme in ``image_url``, and only the first value of each query
        parameter is forwarded.

        Raises:
            aiohttp.ClientResponseError: the server answered with an error status.
        """
        parsed = urlparse(image_url)
        # parse_qs yields a list per key; keep the first value of each.
        query = {k: v[0] for k, v in parse_qs(parsed.query).items()}

        # Certificate and hostname verification are switched off on purpose.
        # NOTE(review): presumably to tolerate the image host's certificates
        # after a redirect to https -- confirm this is still required.
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        async with aiohttp.ClientSession(trust_env=False) as session:
            async with session.get(
                f"http://{parsed.netloc}{parsed.path}",
                params=query,
                ssl=ssl_context
            ) as resp:
                resp.raise_for_status()
                file_bytes = await resp.read()
                return file_bytes

    async def upload_url_image(
        self,
        image_url: str,
    ) -> str:
        """Download an image by URL and upload it to the active backend.

        Args:
            image_url (str): URL of the image to transfer.

        Returns:
            str: URL of the uploaded file as reported by the backend.

        Raises:
            RuntimeError: no OSS service is available.
        """
        # Fail before downloading anything when there is nowhere to upload to.
        if self.service is None:
            raise RuntimeError("没有可用的OSS服务")

        file_bytes = await self.fetch_image(image_url)

        # The extension is fixed to ".jpg" whatever the source format is.
        return await self.service.upload(file_bytes=file_bytes, ext=".jpg")

67
pkg/oss/service.py Normal file
View File

@@ -0,0 +1,67 @@
from __future__ import annotations
import typing
import abc
from ..core import app
# Every class decorated with service_class, in decoration order.
preregistered_services: list[typing.Type[OSSService]] = []


def service_class(
    name: str
) -> typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]:
    """Class decorator that registers an OSS service implementation.

    Args:
        name (str): service name, stored on the class as ``cls.name``.

    Returns:
        typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]:
            the decorator; it returns the class unchanged apart from ``name``.

    Raises:
        TypeError: the decorated object is not a subclass of OSSService.
    """
    def decorator(cls: typing.Type[OSSService]) -> typing.Type[OSSService]:
        # An explicit check instead of `assert`, which is removed under -O.
        if not (isinstance(cls, type) and issubclass(cls, OSSService)):
            raise TypeError(f"{cls!r} is not a subclass of OSSService")

        cls.name = name
        preregistered_services.append(cls)
        return cls

    return decorator
class OSSService(metaclass=abc.ABCMeta):
    """Abstract base class for object-storage backends."""

    # Name of the service type.
    name: str

    # Owning application object.
    ap: app.Application

    # The configuration dict this instance was created with.
    cfg: dict

    def __init__(self, ap: app.Application, cfg: dict) -> None:
        self.ap, self.cfg = ap, cfg

    async def initialize(self):
        """Hook for backend setup; the base implementation does nothing."""
        return None

    @abc.abstractmethod
    async def upload(
        self,
        local_file: str=None,
        file_bytes: bytes=None,
        ext: str=None,
    ) -> str:
        """Upload a file.

        Args:
            local_file (str, optional): path of a local file. Defaults to None.
            file_bytes (bytes, optional): raw file content. Defaults to None.
            ext (str, optional): file extension. Defaults to None.

        Returns:
            str: URL of the uploaded file.
        """

View File

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import uuid
import oss2
from .. import service as osssv
@osssv.service_class('aliyun')
class AliyunOSSService(osssv.OSSService):
    """Aliyun OSS backend built on the ``oss2`` SDK."""

    auth: oss2.Auth

    bucket: oss2.Bucket

    async def initialize(self):
        """Create the auth object and bucket handle from ``self.cfg``.

        Reads the keys ``access-key-id``, ``access-key-secret``, ``endpoint``
        and ``bucket``.
        """
        self.auth = oss2.Auth(
            self.cfg['access-key-id'],
            self.cfg['access-key-secret']
        )
        self.bucket = oss2.Bucket(
            self.auth,
            self.cfg['endpoint'],
            self.cfg['bucket']
        )

    async def upload(
        self,
        local_file: str=None,
        file_bytes: bytes=None,
        ext: str=None,
    ) -> str:
        """Upload a file under a freshly generated object name.

        Args:
            local_file (str, optional): path of a local file; when given, its
                content replaces ``file_bytes``. Defaults to None.
            file_bytes (bytes, optional): raw content to upload. Defaults to None.
            ext (str, optional): extension appended to the object name,
                including the dot. Defaults to None, meaning no extension.

        Returns:
            str: ``public-read-base-url`` joined with the object key.

        Raises:
            Exception: neither ``local_file`` nor ``file_bytes`` provides content.
        """
        if local_file is not None:
            with open(local_file, 'rb') as f:
                file_bytes = f.read()

        if file_bytes is None:
            raise Exception("缺少文件内容")

        name = str(uuid.uuid1())

        # A missing extension must not turn into the literal text "None".
        key = f"{self.cfg['prefix']}/{name}{ext or ''}"

        # NOTE(review): put_object looks like a synchronous SDK call; if so it
        # blocks the event loop for the duration of the upload -- confirm.
        self.bucket.put_object(key, file_bytes)

        return f"{self.cfg['public-read-base-url']}/{key}"

View File

@@ -9,6 +9,7 @@ from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
@stage.stage_class('PostContentFilterStage')
@@ -141,6 +142,21 @@ class ContentFilterStage(stage.PipelineStage):
"""处理
"""
if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
for me in query.message_chain:
if not isinstance(me, mirai.Plain):
contain_non_text = True
break
if contain_non_text:
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
return await self._pre_process(
str(query.message_chain).strip(),
query

View File

@@ -4,6 +4,8 @@ import enum
import pydantic
from ...provider import entities as llm_entities
class ResultLevel(enum.Enum):
"""结果等级"""
@@ -29,6 +31,13 @@ class EnableStage(enum.Enum):
"""后处理"""
@enum.unique
class AcceptContent(enum.Enum):
    """Content modalities a filter can accept."""

    # Plain text.
    TEXT = enum.auto()

    # An image referenced by URL.
    IMAGE_URL = enum.auto()
class FilterResult(pydantic.BaseModel):
level: ResultLevel
"""结果等级
@@ -38,7 +47,7 @@ class FilterResult(pydantic.BaseModel):
"""
replacement: str
"""替换后的消息
"""替换后的文本消息
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
若没有修改内容,也需要返回原消息。

View File

@@ -5,6 +5,7 @@ import typing
from ...core import app
from . import entities
from ...provider import entities as llm_entities
preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -56,6 +57,16 @@ class ContentFilter(metaclass=abc.ABCMeta):
entities.EnableStage.PRE,
entities.EnableStage.POST
]
@property
def accept_content(self):
    """Content modalities this filter accepts.

    The default is plain text only.
    """
    text_only = [entities.AcceptContent.TEXT]
    return text_only
async def initialize(self):
"""初始化过滤器
@@ -63,7 +74,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult:
async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。
@@ -71,6 +82,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
Args:
message (str): 需要检查的内容
image_url (str): 要检查的图片的 URL
Returns:
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import mirai
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...provider import entities as llm_entities
@@ -37,9 +39,31 @@ class PreProcessor(stage.PipelineStage):
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
# 检查vision是否启用没启用就删除所有图片
if not self.ap.provider_cfg.data['enable-vision']:
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
if me.type == 'image_url':
msg.content.remove(me)
content_list = []
for me in query.message_chain:
if isinstance(me, mirai.Plain):
content_list.append(
llm_entities.ContentElement.from_text(me.text)
)
elif isinstance(me, mirai.Image):
if self.ap.provider_cfg.data['enable-vision']:
if me.url is not None:
content_list.append(
llm_entities.ContentElement.from_image_url(str(me.url))
)
query.user_message = llm_entities.Message( # TODO 适配多模态输入
role='user',
content=str(query.message_chain).strip()
content=content_list
)
query.use_model = conversation.use_model

View File

@@ -93,15 +93,28 @@ class CommandHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif ret.text is not None:
elif ret.text is not None or ret.image_url is not None:
content: list[llm_entities.ContentElement]= []
if ret.text is not None:
content.append(
llm_entities.ContentElement.from_text(ret.text)
)
if ret.image_url is not None:
content.append(
llm_entities.ContentElement.from_image_url(ret.image_url)
)
query.resp_messages.append(
llm_entities.Message(
role='command',
content=ret.text,
content=content,
)
)
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@@ -34,17 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
"""
if query.resp_messages[-1].role == 'command':
query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
else:
query.resp_message_chain.append(query.resp_messages[-1].content)
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
# else:
# query.resp_message_chain.append(query.resp_messages[-1].content)
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@@ -59,7 +61,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = ''
if result.content is not None: # 有内容
reply_text = result.content
reply_text = str(result.get_content_mirai_message_chain())
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
@@ -87,7 +89,7 @@ class ResponseWrapper(stage.PipelineStage):
else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
query.resp_message_chain.append(result.get_content_mirai_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@@ -21,14 +21,34 @@ class ToolCall(pydantic.BaseModel):
function: FunctionCall
class Content(pydantic.BaseModel):
class ImageURLContentObject(pydantic.BaseModel):
url: str
class ContentElement(pydantic.BaseModel):
type: str
"""内容类型"""
text: typing.Optional[str] = None
image_url: typing.Optional[str] = None
image_url: typing.Optional[ImageURLContentObject] = None
def __str__(self):
if self.type == 'text':
return self.text
elif self.type == 'image_url':
return f'[图片]({self.image_url})'
else:
return '未知内容'
@classmethod
def from_text(cls, text: str):
return cls(type='text', text=text)
@classmethod
def from_image_url(cls, image_url: str):
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
class Message(pydantic.BaseModel):
@@ -40,7 +60,7 @@ class Message(pydantic.BaseModel):
name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置"""
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
"""内容"""
tool_calls: typing.Optional[list[ToolCall]] = None
@@ -50,8 +70,38 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str:
if self.content is not None:
return str(self.role) + ": " + str(self.content)
return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}'
else:
return '未知消息'
def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
    """Convert the content into a Mirai MessageChain.

    Args:
        prefix_text (str): text prepended to the first text component; when
            the content has no text component, it is inserted as a new
            leading one.

    Returns:
        The message chain, or None when the content is None (or is neither a
        string nor a list).
    """
    if self.content is None:
        return None
    elif isinstance(self.content, str):
        return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
    elif isinstance(self.content, list):
        mc = []
        for ce in self.content:
            if ce.type == 'text':
                mc.append(mirai.Plain(ce.text))
            # Image elements are created with type 'image_url' and carry the
            # address in the wrapped object's `url` field.
            elif ce.type == 'image_url' and ce.image_url is not None:
                mc.append(mirai.Image(url=ce.image_url.url))

        # Attach the prefix to the first text component, if there is one.
        if prefix_text:
            for i, c in enumerate(mc):
                if isinstance(c, mirai.Plain):
                    mc[i] = mirai.Plain(prefix_text+c.text)
                    break
            else:
                # No text component at all: the prefix becomes its own component.
                mc.insert(0, mirai.Plain(prefix_text))

        return mirai.MessageChain(mc)

View File

@@ -38,30 +38,42 @@ class AnthropicMessages(api.LLMAPIRequester):
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = model.name if model.model_name is None else model.model_name
req_messages = [
m.dict(exclude_none=True) for m in messages if m.content.strip() != ""
]
# 处理消息
# 删除所有 role=system & content='' 的消息
req_messages = [
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]
# system
system_role_message = None
# 检查是否有 role=system 的消息,若有,改为 role=user并在后面加一个 role=assistant 的消息
system_role_index = []
for i, m in enumerate(req_messages):
if m["role"] == "system":
system_role_index.append(i)
m["role"] = "user"
for i, m in enumerate(messages):
if m.role == "system":
system_role_message = m
if system_role_index:
for i in system_role_index[::-1]:
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
messages.pop(i)
break
# 忽略掉空消息,用户可能发送空消息,而上层未过滤
req_messages = [
m for m in req_messages if m["content"].strip() != ""
]
if isinstance(system_role_message, llm_entities.Message) \
and isinstance(system_role_message.content, str):
args['system'] = system_role_message.content
# 其他消息
# req_messages = [
# m.dict(exclude_none=True) for m in messages \
# if (isinstance(m.content, str) and m.content.strip() != "") \
# or (isinstance(m.content, list) and )
# ]
# 暂时不支持vision仅保留纯文字的content
req_messages = []
for m in messages:
if isinstance(m.content, str) and m.content.strip() != "":
req_messages.append(m.dict(exclude_none=True))
elif isinstance(m.content, list):
# 删除m.content中的type!=text的元素
m.content = [
c for c in m.content if c.get("type") == "text"
]
if len(m.content) > 0:
req_messages.append(m.dict(exclude_none=True))
args["messages"] = req_messages

View File

@@ -23,9 +23,18 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
requester_cfg: dict
cached_image_oss_url: dict[str, str] = {}
"""缓存的OSS服务的图片URL
key: 前文message中的原图片URLQQ图片
value: OSS服务的图片URL
"""
def __init__(self, ap: app.Application):
self.ap = ap
self.cached_image_oss_url = {}
self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
async def initialize(self):
@@ -74,7 +83,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
args["tools"] = tools
# 设置此次请求中的messages
messages = req_messages
messages = req_messages.copy()
# 检查vision
if self.ap.oss_mgr.available():
for msg in messages:
if isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_url":
me["image_url"]['url'] = await self.get_oss_url(me["image_url"]['url'])
args["messages"] = messages
# 发送请求
@@ -112,3 +130,17 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')
async def get_oss_url(
    self,
    original_url: str,
) -> str:
    """Return the OSS URL for ``original_url``, uploading it on first use.

    Results are remembered in ``self.cached_image_oss_url`` so each original
    URL is uploaded at most once per requester instance.
    """
    try:
        return self.cached_image_oss_url[original_url]
    except KeyError:
        # Not uploaded yet; fall through to the upload.
        pass

    uploaded_url = await self.ap.oss_mgr.upload_url_image(original_url)
    self.cached_image_oss_url[original_url] = uploaded_url
    return uploaded_url

View File

@@ -3,7 +3,10 @@ from __future__ import annotations
from ....core import app
from . import chatcmpl
from .. import api
from .. import api, entities, errors
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("deepseek-chat-completions")
@@ -12,4 +15,39 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
self.ap = ap
self.ap = ap
async def _closure(
    self,
    req_messages: list[dict],
    use_model: entities.LLMModelInfo,
    use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
    """Send one chat-completion request with text-only message content.

    Args:
        req_messages: messages as dicts; ``content`` may be a string or a
            list of content-part dicts.
        use_model: model descriptor supplying the token, the model name and
            the tool-call capability flag.
        use_funcs: functions offered as tools when the model supports them.

    Returns:
        The message built from the response by ``self._make_msg``.
    """
    self.client.api_key = use_model.token_mgr.get_token()

    args = self.requester_cfg['args'].copy()
    args["model"] = use_model.name if use_model.model_name is None else use_model.model_name

    if use_model.tool_call_supported:
        tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

        if tools:
            args["tools"] = tools

    # This requester sends plain text only: list content is flattened to the
    # space-joined text parts and non-text parts (e.g. image_url) are dropped.
    # Shallow copies keep the caller's dicts untouched.
    messages = []
    for m in req_messages:
        content = m.get("content")
        if isinstance(content, list):
            flattened = " ".join(
                c.get("text") or "" for c in content if c.get("type") == "text"
            )
            m = {**m, "content": flattened}
        messages.append(m)

    args["messages"] = messages

    resp = await self._req(args)

    message = await self._make_msg(resp)

    return message

View File

@@ -3,7 +3,10 @@ from __future__ import annotations
from ....core import app
from . import chatcmpl
from .. import api
from .. import api, entities, errors
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("moonshot-chat-completions")
@@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
self.ap = ap
async def _closure(
    self,
    req_messages: list[dict],
    use_model: entities.LLMModelInfo,
    use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
    """Send one chat-completion request with text-only, non-empty messages.

    Args:
        req_messages: messages as dicts; ``content`` may be a string or a
            list of content-part dicts.
        use_model: model descriptor supplying the token, the model name and
            the tool-call capability flag.
        use_funcs: functions offered as tools when the model supports them.

    Returns:
        The message built from the response by ``self._make_msg``.
    """
    self.client.api_key = use_model.token_mgr.get_token()

    args = self.requester_cfg['args'].copy()
    args["model"] = use_model.name if use_model.model_name is None else use_model.model_name

    if use_model.tool_call_supported:
        tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)

        if tools:
            args["tools"] = tools

    # This requester sends plain text only: list content is flattened to the
    # space-joined text parts and non-text parts (e.g. image_url) are dropped.
    # Shallow copies keep the caller's dicts untouched.
    messages = []
    for m in req_messages:
        content = m.get("content")
        if isinstance(content, list):
            content = " ".join(
                c.get("text") or "" for c in content if c.get("type") == "text"
            )
            m = {**m, "content": content}

        # Drop messages whose text is blank, e.g. one that held only an image.
        if isinstance(content, str) and content.strip() == "":
            continue

        messages.append(m)

    args["messages"] = messages

    resp = await self._req(args)

    message = await self._make_msg(resp)

    return message

View File

@@ -37,6 +37,10 @@ class ModelManager:
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
async def initialize(self):
# 检查是否启用了vision但是没有配置oss
if self.ap.provider_cfg.data['enable-vision'] and not self.ap.oss_mgr.available():
self.ap.logger.warn("启用了视觉但是没有配置可用的oss服务基于 URL 传递图片的视觉 API 将无法正常使用")
# 初始化token_mgr, requester
for k, v in self.ap.provider_cfg.data['keys'].items():

View File

@@ -13,4 +13,6 @@ aiohttp
pydantic
websockets
urllib3
psutil
psutil
oss2

View File

@@ -22,6 +22,10 @@
"name": "gpt-4-32k",
"tool_call_supported": true
},
{
"name": "gpt-4o",
"tool_call_supported": true
},
{
"model_name": "SparkDesk",
"name": "OneAPI/SparkDesk"

View File

@@ -1,5 +1,6 @@
{
"enable-chat": true,
"enable-vision": false,
"keys": {
"openai": [
"sk-1234567890"

View File

@@ -1,5 +1,17 @@
{
"admin-sessions": [],
"oss": [
{
"type": "aliyun",
"endpoint": "https://oss-cn-hangzhou.aliyuncs.com",
"public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com",
"access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5",
"access-key-secret": "xxxxxx",
"bucket": "qchatgpt",
"prefix": "qchatgpt",
"enable": false
}
],
"network-proxies": {
"http": null,
"https": null