refactor: 重构openai包基础组件架构

This commit is contained in:
RockChinQ
2024-01-27 00:06:38 +08:00
parent 411034902a
commit 850a4eeb7c
35 changed files with 779 additions and 59 deletions

View File

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
import abc
from ...core import app
from ...core import entities as core_entities
from .. import entities
class MessageHandler(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def handle(
self,
query: core_entities.Query,
) -> entities.StageProcessResult:
raise NotImplementedError

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
import typing
import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
class ChatMessageHandler(handler.MessageHandler):
async def handle(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
# 取session
# 取conversation
# 调API
# 生成器
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
async for result in conversation.use_model.requester.request(query, conversation):
query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
import typing
import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
class CommandHandler(handler.MessageHandler):
async def handle(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
query.resp_message_chain = mirai.MessageChain([
mirai.Plain('CommandHandler')
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
query.resp_message_chain = mirai.MessageChain([
mirai.Plain('The Second Message')
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from ...core import app, entities as core_entities
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage):
cmd_handler: handler.MessageHandler
chat_handler: handler.MessageHandler
async def initialize(self):
self.cmd_handler = command.CommandHandler(self.ap)
self.chat_handler = chat.ChatMessageHandler(self.ap)
await self.cmd_handler.initialize()
await self.chat_handler.initialize()
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""处理
"""
message_text = str(query.message_chain).strip()
if message_text.startswith('!') or message_text.startswith(''):
return self.cmd_handler.handle(query)
else:
return self.chat_handler.handle(query)

View File

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
import mirai
from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("SendResponseBackStage")
class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息
"""
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
await self.ap.im_mgr.send(
query.message_event,
query.resp_message_chain
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import abc
import typing
from ..core import app, entities as core_entities
from . import entities
@@ -37,7 +38,10 @@ class PipelineStage(metaclass=abc.ABCMeta):
self,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
) -> typing.Union[
entities.StageProcessResult,
typing.AsyncGenerator[entities.StageProcessResult, None],
]:
"""处理
"""
raise NotImplementedError

View File

@@ -7,7 +7,20 @@ from . import stage
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
stage_order = [
"GroupRespondRuleCheckStage",
"BanSessionCheckStage",
"PreContentFilterStage",
"MessageProcessor",
"PostContentFilterStage",
"LongTextProcessStage",
"SendResponseBackStage",
]
class StageInstContainer():
@@ -45,3 +58,6 @@ class StageManager:
for stage_containers in self.stage_containers:
await stage_containers.inst.initialize()
# 按照 stage_order 排序
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))