refactor: 分隔LLM请求过程和消息封装过程

This commit is contained in:
RockChinQ
2024-02-01 15:48:26 +08:00
parent 32162afa65
commit 976a9de39c
11 changed files with 205 additions and 132 deletions

View File

@@ -66,11 +66,12 @@ class Controller:
self.ap.logger.error(f"控制器循环出错: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
async def _check_output(self, result: pipeline_entities.StageProcessResult):
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出
"""
if result.user_notice:
await self.ap.im_mgr.send(
query.message_event,
result.user_notice
)
if result.debug_notice:
@@ -108,12 +109,14 @@ class Controller:
while i < len(self.ap.stage_mgr.stage_containers):
stage_container = self.ap.stage_mgr.stage_containers[i]
result = await stage_container.inst.process(query, stage_container.inst_name)
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
await self._check_output(result)
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
@@ -125,7 +128,7 @@ class Controller:
async for sub_result in result:
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
await self._check_output(sub_result)
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")

View File

@@ -27,25 +27,30 @@ class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID"""
"""请求ID,添加进请求池时生成"""
launcher_type: LauncherTypes
"""会话类型"""
"""会话类型platform设置"""
launcher_id: int
"""会话ID"""
"""会话IDplatform设置"""
sender_id: int
"""发送者ID"""
"""发送者IDplatform设置"""
message_event: mirai.MessageEvent
"""事件"""
"""事件platform收到的事件"""
message_chain: mirai.MessageChain
"""消息链"""
"""消息链platform收到的消息链"""
session: typing.Optional[Session] = None
resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""由provider生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链"""
"""回复消息链从resp_messages包装而得"""
class Conversation(pydantic.BaseModel):

View File

@@ -38,7 +38,9 @@ class QueryPool:
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain
message_chain=message_chain,
resp_messages=[],
resp_message_chain=None
)
self.queries.append(query)
self.query_id_counter += 1