feat: runner 层抽象 (#839)

This commit is contained in:
RockChinQ
2024-07-28 18:45:27 +08:00
parent 48cc3656bd
commit 8cad4089a7
10 changed files with 172 additions and 65 deletions
+4 -63
View File
@@ -10,7 +10,7 @@ import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....provider import entities as llm_entities, runnermgr
from ....plugin import events
@@ -71,7 +71,9 @@ class ChatMessageHandler(handler.MessageHandler):
try:
async for result in self.runner(query):
runner = self.ap.runner_mgr.get_runner()
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
@@ -108,64 +110,3 @@ class ChatMessageHandler(handler.MessageHandler):
response_seconds=int(time.time() - start_time),
retry_times=-1,
)
async def runner(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)