refactor: 分隔LLM请求过程和消息封装过程

This commit is contained in:
RockChinQ
2024-02-01 15:48:26 +08:00
parent 32162afa65
commit 976a9de39c
11 changed files with 205 additions and 132 deletions
+5 -88
View File
@@ -83,104 +83,21 @@ class ChatMessageHandler(handler.MessageHandler):
)
)
called_functions = []
text_length = 0
start_time = time.time()
async for result in conversation.use_model.requester.request(query, conversation):
conversation.messages.append(result)
query.resp_messages.append(result)
if result.content is not None:
text_length += len(result.content)
# 转换成可读消息
if result.role == 'assistant':
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
reply_text = ''
if result.content is not None: # 有内容
reply_text = result.content
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=called_functions,
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
called_functions.extend(function_names)
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.cfg_mgr.data['trace_function_calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=called_functions,
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
await self.ap.ctr_mgr.usage.post_query_record(
session_type=session.launcher_type.value,
session_id=str(session.launcher_id),
+27 -7
View File
@@ -6,6 +6,7 @@ import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
@@ -44,7 +45,14 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(
llm_entities.Message(
role='command',
content=str(mc),
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@@ -73,18 +81,30 @@ class CommandHandler(handler.MessageHandler):
session=session
):
if ret.error is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(str(ret.error))
])
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(str(ret.error))
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',
content=str(ret.error),
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif ret.text is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(ret.text)
])
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(ret.text)
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',
content=ret.text,
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,