diff --git a/pkg/core/app.py b/pkg/core/app.py index c0ce12fa..c9d06e15 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -8,6 +8,7 @@ from ..openai import manager as openai_mgr from ..openai.session import sessionmgr as llm_session_mgr from ..openai.requester import modelmgr as llm_model_mgr from ..openai.sysprompt import sysprompt as llm_prompt_mgr +from ..openai.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr @@ -27,6 +28,8 @@ class Application: prompt_mgr: llm_prompt_mgr.PromptManager = None + tool_mgr: llm_tool_mgr.ToolManager = None + cfg_mgr: config_mgr.ConfigManager = None tips_mgr: config_mgr.ConfigManager = None @@ -46,10 +49,21 @@ class Application: def __init__(self): pass - async def run(self): - # TODO make it async + async def initialize(self): plugin_host.initialize_plugins() + # 把现有的所有内容函数加到toolmgr里 + for func in plugin_host.__callable_functions__: + print(func) + self.tool_mgr.register_legacy_function( + name=func['name'], + description=func['description'], + parameters=func['parameters'], + func=plugin_host.__function_inst_map__[func['name']] + ) + + async def run(self): + tasks = [ asyncio.create_task(self.im_mgr.run()), asyncio.create_task(self.ctrl.run()) diff --git a/pkg/core/boot.py b/pkg/core/boot.py index a74615ec..c06cc6cd 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -18,6 +18,7 @@ from ..openai import manager as llm_mgr from ..openai.session import sessionmgr as llm_session_mgr from ..openai.requester import modelmgr as llm_model_mgr from ..openai.sysprompt import sysprompt as llm_prompt_mgr +from ..openai.tools import toolmgr as llm_tool_mgr from ..openai import dprompt as llm_dprompt from ..qqbot import manager as im_mgr from ..qqbot.cmds import aamgr as im_cmd_aamgr @@ -127,6 +128,10 @@ async def make_app() -> app.Application: await llm_prompt_mgr_inst.initialize() ap.prompt_mgr = llm_prompt_mgr_inst + llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) + await llm_tool_mgr_inst.initialize() + ap.tool_mgr = llm_tool_mgr_inst + im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap) await im_mgr_inst.initialize() ap.im_mgr = im_mgr_inst @@ -140,7 +145,8 @@ async def make_app() -> app.Application: # TODO make it async plugin_host.load_plugins() - # plugin_host.initialize_plugins() + + await ap.initialize() return ap diff --git a/pkg/openai/entities.py b/pkg/openai/entities.py index 58f48d95..2dd5804b 100644 --- a/pkg/openai/entities.py +++ b/pkg/openai/entities.py @@ -5,27 +5,29 @@ import enum import pydantic -class MessageRole(enum.Enum): - - SYSTEM = 'system' - - USER = 'user' - - ASSISTANT = 'assistant' - - FUNCTION = 'function' - - class FunctionCall(pydantic.BaseModel): name: str - args: dict[str, typing.Any] + arguments: str + + +class ToolCall(pydantic.BaseModel): + id: str + + type: str + + function: FunctionCall class Message(pydantic.BaseModel): + role: str - role: MessageRole + name: typing.Optional[str] = None content: typing.Optional[str] = None function_call: typing.Optional[FunctionCall] = None + + tool_calls: typing.Optional[list[ToolCall]] = None + + tool_call_id: typing.Optional[str] = None diff --git a/pkg/openai/requester/apis/chatcmpl.py b/pkg/openai/requester/apis/chatcmpl.py index 5b6d2297..24ff2d7e 100644 --- a/pkg/openai/requester/apis/chatcmpl.py +++ b/pkg/openai/requester/apis/chatcmpl.py @@ -2,8 +2,10 @@ from __future__ import annotations import asyncio import typing +import json import openai +import openai.types.chat.chat_completion as chat_completion from .. import api from ....core import entities as core_entities @@ -12,21 +14,127 @@ from ...session import entities as session_entities class OpenAIChatCompletion(api.LLMAPIRequester): - - client: openai.Client + client: openai.AsyncClient async def initialize(self): - self.client = openai.Client( - base_url=self.ap.cfg_mgr.data['openai_config']['reverse_proxy'], - timeout=self.ap.cfg_mgr.data['process_message_timeout'] + self.client = openai.AsyncClient( + api_key="", + base_url=self.ap.cfg_mgr.data["openai_config"]["reverse_proxy"], + timeout=self.ap.cfg_mgr.data["process_message_timeout"], ) - async def request(self, query: core_entities.Query, conversation: session_entities.Conversation) -> typing.AsyncGenerator[llm_entities.Message, None]: - """请求 - """ - await asyncio.sleep(10) + async def _req( + self, + args: dict, + ) -> chat_completion.ChatCompletion: + self.ap.logger.debug(f"req chat_completion with args {args}") + return await self.client.chat.completions.create(**args) - yield llm_entities.Message( - role=llm_entities.MessageRole.ASSISTANT, - content="hello" - ) + async def _make_msg( + self, + chat_completion: chat_completion.ChatCompletion, + ) -> llm_entities.Message: + chatcmpl_message = chat_completion.choices[0].message.dict() + + message = llm_entities.Message(**chatcmpl_message) + + return message + + async def _closure( + self, + req_messages: list[dict], + conversation: session_entities.Conversation, + user_text: str = None, + function_ret: str = None, + ) -> llm_entities.Message: + self.client.api_key = conversation.use_model.token_mgr.get_token() + + args = self.ap.cfg_mgr.data["completion_api_params"].copy() + args["model"] = conversation.use_model.name + + tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation) + # tools = [ + # { + # "type": "function", + # "function": { + # "name": "get_current_weather", + # "description": "Get the current weather in a given location", + # "parameters": { + # "type": "object", + # "properties": { + # "location": { + # "type": "string", + # "description": "The city and state, e.g. San Francisco, CA", + # }, + # "unit": { + # "type": "string", + # "enum": ["celsius", "fahrenheit"], + # }, + # }, + # "required": ["location"], + # }, + # }, + # } + # ] + if tools: + args["tools"] = tools + + # 设置此次请求中的messages + messages = req_messages + args["messages"] = messages + + # 发送请求 + resp = await self._req(args) + + # 处理请求结果 + message = await self._make_msg(resp) + + return message + + async def request( + self, query: core_entities.Query, conversation: session_entities.Conversation + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """请求""" + + pending_tool_calls = [] + + req_messages = [ + m.dict(exclude_none=True) for m in conversation.prompt.messages + ] + [m.dict(exclude_none=True) for m in conversation.messages] + + req_messages.append({"role": "user", "content": str(query.message_chain)}) + + msg = await self._closure(req_messages, conversation) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg.dict(exclude_none=True)) + + while pending_tool_calls: + for tool_call in pending_tool_calls: + func = tool_call.function + + parameters = json.loads(func.arguments) + + func_ret = await self.ap.tool_mgr.execute_func_call( + query, func.name, parameters + ) + + msg = llm_entities.Message( + role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id + ) + + yield msg + + req_messages.append(msg.dict(exclude_none=True)) + + # 处理完所有调用,继续请求 + msg = await self._closure(req_messages, conversation) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg.dict(exclude_none=True)) diff --git a/pkg/openai/requester/modelmgr.py b/pkg/openai/requester/modelmgr.py index cc606b03..7e6a3b52 100644 --- a/pkg/openai/requester/modelmgr.py +++ b/pkg/openai/requester/modelmgr.py @@ -19,7 +19,8 @@ class ModelManager: async def initialize(self): openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap) - openai_token_mgr = token.TokenManager(self.ap, self.ap.cfg_mgr.data['openai_config']['api_key'].values()) + await openai_chat_completion.initialize() + openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values())) self.model_list.append( entities.LLMModelInfo( diff --git a/pkg/openai/session/entities.py b/pkg/openai/session/entities.py index 49ddb845..cbeb72a3 100644 --- a/pkg/openai/session/entities.py +++ b/pkg/openai/session/entities.py @@ -10,6 +10,7 @@ from ..sysprompt import entities as sysprompt_entities from .. import entities as llm_entities from ..requester import entities from ...core import entities as core_entities +from ..tools import entities as tools_entities class Conversation(pydantic.BaseModel): @@ -25,6 +26,8 @@ class Conversation(pydantic.BaseModel): use_model: entities.LLMModelInfo + use_funcs: typing.Optional[list[tools_entities.LLMFunction]] + class Session(pydantic.BaseModel): """会话""" diff --git a/pkg/openai/session/sessionmgr.py b/pkg/openai/session/sessionmgr.py index 8aff6e02..a1d5d4d9 100644 --- a/pkg/openai/session/sessionmgr.py +++ b/pkg/openai/session/sessionmgr.py @@ -29,7 +29,7 @@ class SessionManager: session = entities.Session( launcher_type=query.launcher_type, launcher_id=query.launcher_id, - semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000) + semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000), ) self.session_list.append(session) return session @@ -43,6 +43,7 @@ class SessionManager: prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), messages=[], use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']), + use_funcs=await self.ap.tool_mgr.get_all_functions(), ) session.conversations.append(conversation) session.using_conversation = conversation diff --git a/pkg/openai/sysprompt/loaders/scenario.py b/pkg/openai/sysprompt/loaders/scenario.py index e0c2bd33..4d54f30f 100644 --- a/pkg/openai/sysprompt/loaders/scenario.py +++ b/pkg/openai/sysprompt/loaders/scenario.py @@ -21,14 +21,9 @@ class ScenarioPromptLoader(loader.PromptLoader): file_json = json.loads(file_str) messages = [] for msg in file_json["prompt"]: - role = llm_entities.MessageRole.SYSTEM + role = 'system' if "role" in msg: - if msg["role"] == "user": - role = llm_entities.MessageRole.USER - elif msg["role"] == "system": - role = llm_entities.MessageRole.SYSTEM - elif msg["role"] == "function": - role = llm_entities.MessageRole.FUNCTION + role = msg['role'] messages.append( llm_entities.Message( role=role, diff --git a/pkg/openai/sysprompt/loaders/single.py b/pkg/openai/sysprompt/loaders/single.py index ad37d878..1fff5a69 100644 --- a/pkg/openai/sysprompt/loaders/single.py +++ b/pkg/openai/sysprompt/loaders/single.py @@ -19,7 +19,7 @@ class SingleSystemPromptLoader(loader.PromptLoader): name=name, messages=[ llm_entities.Message( - role=llm_entities.MessageRole.SYSTEM, + role='system', content=cnt ) ] @@ -34,7 +34,7 @@ class SingleSystemPromptLoader(loader.PromptLoader): name=file_name, messages=[ llm_entities.Message( - role=llm_entities.MessageRole.SYSTEM, + role='system', content=file_str ) ] diff --git a/pkg/openai/tools/__init__.py b/pkg/openai/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/openai/tools/entities.py b/pkg/openai/tools/entities.py new file mode 100644 index 00000000..b79627e5 --- /dev/null +++ b/pkg/openai/tools/entities.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import abc +import typing +import asyncio + +import pydantic + + +class LLMFunction(pydantic.BaseModel): + """函数""" + + name: str + """函数名""" + + human_desc: str + + description: str + """给LLM识别的函数描述""" + + enable: typing.Optional[bool] = True + + parameters: dict + + func: typing.Callable + """供调用的python异步方法 + + 此异步方法第一个参数接收当前请求的query对象,可以从其中取出session等信息。 + query参数不在parameters中,但在调用时会自动传入。 + 但在当前版本中,插件提供的内容函数都是同步的,且均为请求无关的,故在此版本的实现(以及考虑了向后兼容性的版本)中, + 对插件的内容函数进行封装并存到这里来。 + """ + + class Config: + arbitrary_types_allowed = True diff --git a/pkg/openai/tools/toolmgr.py b/pkg/openai/tools/toolmgr.py new file mode 100644 index 00000000..cc160e39 --- /dev/null +++ b/pkg/openai/tools/toolmgr.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import typing + +from ...core import app, entities as core_entities +from . import entities +from ..session import entities as session_entities + + +class ToolManager: + """LLM工具管理器 + """ + + ap: app.Application + + all_functions: list[entities.LLMFunction] + + def __init__(self, ap: app.Application): + self.ap = ap + self.all_functions = [] + + async def initialize(self): + pass + + def register_legacy_function(self, name: str, description: str, parameters: dict, func: callable): + """注册函数 + """ + async def wrapper(query, **kwargs): + return func(**kwargs) + function = entities.LLMFunction( + name=name, + description=description, + human_desc='', + enable=True, + parameters=parameters, + func=wrapper + ) + self.all_functions.append(function) + + async def register_function(self, function: entities.LLMFunction): + """添加函数 + """ + self.all_functions.append(function) + + async def get_function(self, name: str) -> entities.LLMFunction: + """获取函数 + """ + for function in self.all_functions: + if function.name == name: + return function + return None + + async def get_all_functions(self) -> list[entities.LLMFunction]: + """获取所有函数 + """ + return self.all_functions + + async def generate_tools_for_openai(self, conversation: session_entities.Conversation) -> str: + """生成函数列表 + """ + tools = [] + + for function in conversation.use_funcs: + if function.enable: + function_schema = { + "type": "function", + "function": { + "name": function.name, + "description": function.description, + "parameters": function.parameters + } + } + tools.append(function_schema) + + return tools + + async def execute_func_call( + self, + query: core_entities.Query, + name: str, + parameters: dict + ) -> typing.Any: + """执行函数调用 + """ + + # return "i'm not sure for the args "+str(parameters) + + function = await self.get_function(name) + if function is None: + return None + + parameters = parameters.copy() + + parameters = { + "query": query, + **parameters + } + + return await function.func(**parameters) diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 11144891..72c36cdf 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -50,8 +50,8 @@ class LongTextProcessStage(stage.PipelineStage): async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']: - query.message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain))) + query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain))) return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query - ) \ No newline at end of file + ) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index ebe958bf..629c2b11 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -26,13 +26,11 @@ class ChatMessageHandler(handler.MessageHandler): conversation = await self.ap.sess_mgr.get_conversation(session) async for result in conversation.use_model.requester.request(query, conversation): + conversation.messages.append(result) + query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))]) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query ) - - - -