From 7db56c8e7708f6195cbfbaefdaba66c4fc369b16 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 22 May 2024 20:09:29 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20claude=20=E6=94=AF=E6=8C=81=E8=A7=86?= =?UTF-8?q?=E8=A7=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/provider/modelmgr/apis/anthropicmsgs.py | 43 +++++--- pkg/provider/tools/toolmgr.py | 104 +++++++++++++------- 2 files changed, 96 insertions(+), 51 deletions(-) diff --git a/pkg/provider/modelmgr/apis/anthropicmsgs.py b/pkg/provider/modelmgr/apis/anthropicmsgs.py index 6c18bf51..a9a3f05d 100644 --- a/pkg/provider/modelmgr/apis/anthropicmsgs.py +++ b/pkg/provider/modelmgr/apis/anthropicmsgs.py @@ -11,6 +11,7 @@ from .. import api, entities, errors from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities +from ....utils import image @api.requester_class("anthropic-messages") @@ -54,29 +55,45 @@ class AnthropicMessages(api.LLMAPIRequester): and isinstance(system_role_message.content, str): args['system'] = system_role_message.content - # 其他消息 - # req_messages = [ - # m.dict(exclude_none=True) for m in messages \ - # if (isinstance(m.content, str) and m.content.strip() != "") \ - # or (isinstance(m.content, list) and ) - # ] - # 暂时不支持vision,仅保留纯文字的content req_messages = [] for m in messages: if isinstance(m.content, str) and m.content.strip() != "": req_messages.append(m.dict(exclude_none=True)) elif isinstance(m.content, list): - # 删除m.content中的type!=text的元素 - m.content = [ - c for c in m.content if c.type == "text" - ] + # m.content = [ + # c for c in m.content if c.type == "text" + # ] - if len(m.content) > 0: - req_messages.append(m.dict(exclude_none=True)) + # if len(m.content) > 0: + # req_messages.append(m.dict(exclude_none=True)) + + msg_dict = m.dict(exclude_none=True) + + for i, ce in enumerate(m.content): + if ce.type == "image_url": + alter_image_ele = { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": await image.qq_image_url_to_base64(ce.image_url.url) + } + } + msg_dict["content"][i] = alter_image_ele + + req_messages.append(msg_dict) args["messages"] = req_messages + # anthropic的tools处在beta阶段,sdk不稳定,故暂时不支持 + # + # if funcs: + # tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) + + # if tools: + # args["tools"] = tools + try: resp = await self.client.messages.create(**args) diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index d233b5e3..5e780c50 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -9,11 +9,10 @@ from ...plugin import context as plugin_context class ToolManager: - """LLM工具管理器 - """ + """LLM工具管理器""" ap: app.Application - + def __init__(self, ap: app.Application): self.ap = ap self.all_functions = [] @@ -22,35 +21,33 @@ class ToolManager: pass async def get_function(self, name: str) -> entities.LLMFunction: - """获取函数 - """ + """获取函数""" for function in await self.get_all_functions(): if function.name == name: return function return None - - async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: - """获取函数和插件 - """ + + async def get_function_and_plugin( + self, name: str + ) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]: + """获取函数和插件""" for plugin in self.ap.plugin_mgr.plugins: for function in plugin.content_functions: if function.name == name: return function, plugin.plugin_inst return None, None - + async def get_all_functions(self) -> list[entities.LLMFunction]: - """获取所有函数 - """ + """获取所有函数""" all_functions: list[entities.LLMFunction] = [] - + for plugin in self.ap.plugin_mgr.plugins: all_functions.extend(plugin.content_functions) - + return all_functions - async def generate_tools_for_openai(self, use_funcs: entities.LLMFunction) -> str: - """生成函数列表 - """ + async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list: + """生成函数列表""" tools = [] for function in use_funcs: @@ -60,40 +57,71 @@ class ToolManager: "function": { "name": function.name, "description": function.description, - "parameters": function.parameters - } + "parameters": function.parameters, + }, + } + tools.append(function_schema) + + return tools + + async def generate_tools_for_anthropic( + self, use_funcs: list[entities.LLMFunction] + ) -> list: + """为anthropic生成函数列表 + + e.g. + + [ + { + "name": "get_stock_price", + "description": "Get the current stock price for a given ticker symbol.", + "input_schema": { + "type": "object", + "properties": { + "ticker": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL for Apple Inc." + } + }, + "required": ["ticker"] + } + } + ] + """ + + tools = [] + + for function in use_funcs: + if function.enable: + function_schema = { + "name": function.name, + "description": function.description, + "input_schema": function.parameters, } tools.append(function_schema) return tools async def execute_func_call( - self, - query: core_entities.Query, - name: str, - parameters: dict + self, query: core_entities.Query, name: str, parameters: dict ) -> typing.Any: - """执行函数调用 - """ + """执行函数调用""" try: function, plugin = await self.get_function_and_plugin(name) if function is None: return None - + parameters = parameters.copy() - parameters = { - "query": query, - **parameters - } - + parameters = {"query": query, **parameters} + return await function.func(plugin, **parameters) except Exception as e: - self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') + self.ap.logger.error(f"执行函数 {name} 时发生错误: {e}") traceback.print_exc() - return f'error occurred when executing function {name}: {e}' + return f"error occurred when executing function {name}: {e}" finally: plugin = None @@ -107,11 +135,11 @@ class ToolManager: await self.ap.ctr_mgr.usage.post_function_record( plugin={ - 'name': plugin.plugin_name, - 'remote': plugin.plugin_source, - 'version': plugin.plugin_version, - 'author': plugin.plugin_author + "name": plugin.plugin_name, + "remote": plugin.plugin_source, + "version": plugin.plugin_version, + "author": plugin.plugin_author, }, function_name=function.name, function_description=function.description, - ) \ No newline at end of file + )