mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-28 00:14:21 +00:00
refactor: 重构openai包基础组件架构
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from . import model as file_model
|
from . import model as file_model
|
||||||
from ..utils import context
|
from ..utils import context
|
||||||
from .impls import pymodule, json as json_file
|
from .impls import pymodule, json as json_file
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ import asyncio
|
|||||||
|
|
||||||
from ..qqbot import manager as qqbot_mgr
|
from ..qqbot import manager as qqbot_mgr
|
||||||
from ..openai import manager as openai_mgr
|
from ..openai import manager as openai_mgr
|
||||||
|
from ..openai.session import sessionmgr as llm_session_mgr
|
||||||
|
from ..openai.requester import modelmgr as llm_model_mgr
|
||||||
|
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||||
from ..config import manager as config_mgr
|
from ..config import manager as config_mgr
|
||||||
from ..database import manager as database_mgr
|
from ..database import manager as database_mgr
|
||||||
from ..utils.center import v2 as center_mgr
|
from ..utils.center import v2 as center_mgr
|
||||||
@@ -18,6 +21,12 @@ class Application:
|
|||||||
|
|
||||||
llm_mgr: openai_mgr.OpenAIInteract = None
|
llm_mgr: openai_mgr.OpenAIInteract = None
|
||||||
|
|
||||||
|
sess_mgr: llm_session_mgr.SessionManager = None
|
||||||
|
|
||||||
|
model_mgr: llm_model_mgr.ModelManager = None
|
||||||
|
|
||||||
|
prompt_mgr: llm_prompt_mgr.PromptManager = None
|
||||||
|
|
||||||
cfg_mgr: config_mgr.ConfigManager = None
|
cfg_mgr: config_mgr.ConfigManager = None
|
||||||
|
|
||||||
tips_mgr: config_mgr.ConfigManager = None
|
tips_mgr: config_mgr.ConfigManager = None
|
||||||
|
|||||||
+15
-3
@@ -15,7 +15,9 @@ from ..pipeline import stagemgr
|
|||||||
from ..audit import identifier
|
from ..audit import identifier
|
||||||
from ..database import manager as db_mgr
|
from ..database import manager as db_mgr
|
||||||
from ..openai import manager as llm_mgr
|
from ..openai import manager as llm_mgr
|
||||||
from ..openai import session as llm_session
|
from ..openai.session import sessionmgr as llm_session_mgr
|
||||||
|
from ..openai.requester import modelmgr as llm_model_mgr
|
||||||
|
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||||
from ..openai import dprompt as llm_dprompt
|
from ..openai import dprompt as llm_dprompt
|
||||||
from ..qqbot import manager as im_mgr
|
from ..qqbot import manager as im_mgr
|
||||||
from ..qqbot.cmds import aamgr as im_cmd_aamgr
|
from ..qqbot.cmds import aamgr as im_cmd_aamgr
|
||||||
@@ -112,8 +114,18 @@ async def make_app() -> app.Application:
|
|||||||
|
|
||||||
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
|
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
|
||||||
ap.llm_mgr = llm_mgr_inst
|
ap.llm_mgr = llm_mgr_inst
|
||||||
# TODO make it async
|
|
||||||
llm_session.load_sessions()
|
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
|
||||||
|
await llm_model_mgr_inst.initialize()
|
||||||
|
ap.model_mgr = llm_model_mgr_inst
|
||||||
|
|
||||||
|
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
|
||||||
|
await llm_session_mgr_inst.initialize()
|
||||||
|
ap.sess_mgr = llm_session_mgr_inst
|
||||||
|
|
||||||
|
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
|
||||||
|
await llm_prompt_mgr_inst.initialize()
|
||||||
|
ap.prompt_mgr = llm_prompt_mgr_inst
|
||||||
|
|
||||||
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
|
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
|
||||||
await im_mgr_inst.initialize()
|
await im_mgr_inst.initialize()
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from ...config import manager as config_mgr
|
from ...config import manager as config_mgr
|
||||||
|
|||||||
+107
-37
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import typing
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from . import app, entities
|
from . import app, entities
|
||||||
@@ -24,25 +25,115 @@ class Controller:
|
|||||||
async def consumer(self):
|
async def consumer(self):
|
||||||
"""事件处理循环
|
"""事件处理循环
|
||||||
"""
|
"""
|
||||||
while True:
|
try:
|
||||||
selected_query: entities.Query = None
|
while True:
|
||||||
|
selected_query: entities.Query = None
|
||||||
|
|
||||||
# 取请求
|
# 取请求
|
||||||
async with self.ap.query_pool:
|
async with self.ap.query_pool:
|
||||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
queries: list[entities.Query] = self.ap.query_pool.queries
|
||||||
|
|
||||||
if queries:
|
for query in queries:
|
||||||
selected_query = queries.pop(0) # FCFS
|
session = await self.ap.sess_mgr.get_session(query)
|
||||||
else:
|
self.ap.logger.debug(f"Checking query {query} session {session}")
|
||||||
await self.ap.query_pool.condition.wait()
|
|
||||||
continue
|
|
||||||
|
|
||||||
if selected_query:
|
if not session.semaphore.locked():
|
||||||
async def _process_query(selected_query):
|
selected_query = query
|
||||||
async with self.semaphore:
|
await session.semaphore.acquire()
|
||||||
await self.process_query(selected_query)
|
|
||||||
|
|
||||||
asyncio.create_task(_process_query(selected_query))
|
break
|
||||||
|
|
||||||
|
if selected_query: # 找到了
|
||||||
|
queries.remove(selected_query)
|
||||||
|
else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限
|
||||||
|
await self.ap.query_pool.condition.wait()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if selected_query:
|
||||||
|
async def _process_query(selected_query):
|
||||||
|
async with self.semaphore: # 总并发上限
|
||||||
|
await self.process_query(selected_query)
|
||||||
|
|
||||||
|
async with self.ap.query_pool:
|
||||||
|
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||||
|
# 通知其他协程,有新的请求可以处理了
|
||||||
|
self.ap.query_pool.condition.notify_all()
|
||||||
|
|
||||||
|
asyncio.create_task(_process_query(selected_query))
|
||||||
|
except Exception as e:
|
||||||
|
self.ap.logger.error(f"事件处理循环出错: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def _check_output(self, result: pipeline_entities.StageProcessResult):
|
||||||
|
"""检查输出
|
||||||
|
"""
|
||||||
|
if result.user_notice:
|
||||||
|
await self.ap.im_mgr.send(
|
||||||
|
result.user_notice
|
||||||
|
)
|
||||||
|
if result.debug_notice:
|
||||||
|
self.ap.logger.debug(result.debug_notice)
|
||||||
|
if result.console_notice:
|
||||||
|
self.ap.logger.info(result.console_notice)
|
||||||
|
|
||||||
|
async def _execute_from_stage(
|
||||||
|
self,
|
||||||
|
stage_index: int,
|
||||||
|
query: entities.Query,
|
||||||
|
):
|
||||||
|
"""从指定阶段开始执行
|
||||||
|
|
||||||
|
如何看懂这里为什么这么写?
|
||||||
|
去问 GPT-4:
|
||||||
|
Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None],
|
||||||
|
如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result,
|
||||||
|
调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
|
||||||
|
Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage:
|
||||||
|
|
||||||
|
A B C D E F G
|
||||||
|
|
||||||
|
如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是:
|
||||||
|
|
||||||
|
A B C D E F G
|
||||||
|
|
||||||
|
现在假设C返回的是AsyncGenerator,那么执行顺序是:
|
||||||
|
|
||||||
|
A B C D E F G C D E F G C D E F G ...
|
||||||
|
Q3: 但是如果不止一个stage会返回生成器呢?
|
||||||
|
"""
|
||||||
|
i = stage_index
|
||||||
|
|
||||||
|
while i < len(self.ap.stage_mgr.stage_containers):
|
||||||
|
stage_container = self.ap.stage_mgr.stage_containers[i]
|
||||||
|
|
||||||
|
result = await stage_container.inst.process(query, stage_container.inst_name)
|
||||||
|
|
||||||
|
|
||||||
|
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||||
|
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||||
|
await self._check_output(result)
|
||||||
|
|
||||||
|
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||||
|
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||||
|
break
|
||||||
|
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||||
|
query = result.new_query
|
||||||
|
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||||
|
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
|
||||||
|
|
||||||
|
async for sub_result in result:
|
||||||
|
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||||
|
await self._check_output(sub_result)
|
||||||
|
|
||||||
|
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||||
|
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||||
|
break
|
||||||
|
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||||
|
query = sub_result.new_query
|
||||||
|
await self._execute_from_stage(i + 1, query)
|
||||||
|
break
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
async def process_query(self, query: entities.Query):
|
async def process_query(self, query: entities.Query):
|
||||||
"""处理请求
|
"""处理请求
|
||||||
@@ -50,28 +141,7 @@ class Controller:
|
|||||||
self.ap.logger.debug(f"Processing query {query}")
|
self.ap.logger.debug(f"Processing query {query}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for stage_container in self.ap.stage_mgr.stage_containers:
|
await self._execute_from_stage(0, query)
|
||||||
res = await stage_container.inst.process(query, stage_container.inst_name)
|
|
||||||
|
|
||||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}")
|
|
||||||
|
|
||||||
if res.user_notice:
|
|
||||||
await self.ap.im_mgr.send(
|
|
||||||
query.message_event,
|
|
||||||
res.user_notice
|
|
||||||
)
|
|
||||||
if res.debug_notice:
|
|
||||||
self.ap.logger.debug(res.debug_notice)
|
|
||||||
if res.console_notice:
|
|
||||||
self.ap.logger.info(res.console_notice)
|
|
||||||
|
|
||||||
if res.result_type == pipeline_entities.ResultType.INTERRUPT:
|
|
||||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
|
||||||
break
|
|
||||||
elif res.result_type == pipeline_entities.ResultType.CONTINUE:
|
|
||||||
query = res.new_query
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
|
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
import enum
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRole(enum.Enum):
|
||||||
|
|
||||||
|
SYSTEM = 'system'
|
||||||
|
|
||||||
|
USER = 'user'
|
||||||
|
|
||||||
|
ASSISTANT = 'assistant'
|
||||||
|
|
||||||
|
FUNCTION = 'function'
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionCall(pydantic.BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
args: dict[str, typing.Any]
|
||||||
|
|
||||||
|
|
||||||
|
class Message(pydantic.BaseModel):
|
||||||
|
|
||||||
|
role: MessageRole
|
||||||
|
|
||||||
|
content: typing.Optional[str] = None
|
||||||
|
|
||||||
|
function_call: typing.Optional[FunctionCall] = None
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from ...core import app
|
||||||
|
from ...core import entities as core_entities
|
||||||
|
from .. import entities as llm_entities
|
||||||
|
from ..session import entities as session_entities
|
||||||
|
|
||||||
|
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||||
|
"""LLM API请求器
|
||||||
|
"""
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
self.ap = ap
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def request(
|
||||||
|
self,
|
||||||
|
query: core_entities.Query,
|
||||||
|
conversation: session_entities.Conversation,
|
||||||
|
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
|
"""请求
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from .. import api
|
||||||
|
from ....core import entities as core_entities
|
||||||
|
from ... import entities as llm_entities
|
||||||
|
from ...session import entities as session_entities
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||||
|
|
||||||
|
client: openai.Client
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
self.client = openai.Client(
|
||||||
|
base_url=self.ap.cfg_mgr.data['openai_config']['reverse_proxy'],
|
||||||
|
timeout=self.ap.cfg_mgr.data['process_message_timeout']
|
||||||
|
)
|
||||||
|
|
||||||
|
async def request(self, query: core_entities.Query, conversation: session_entities.Conversation) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||||
|
"""请求
|
||||||
|
"""
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
yield llm_entities.Message(
|
||||||
|
role=llm_entities.MessageRole.ASSISTANT,
|
||||||
|
content="hello"
|
||||||
|
)
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
import typing
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
from . import api
|
||||||
|
from . import token
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModelInfo(pydantic.BaseModel):
|
||||||
|
"""模型"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
token_mgr: token.TokenManager
|
||||||
|
|
||||||
|
requester: api.LLMAPIRequester
|
||||||
|
|
||||||
|
function_call_supported: typing.Optional[bool] = False
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from . import entities
|
||||||
|
from ...core import app
|
||||||
|
|
||||||
|
from .apis import chatcmpl
|
||||||
|
from . import token
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManager:
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
model_list: list[entities.LLMModelInfo]
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
self.ap = ap
|
||||||
|
self.model_list = []
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
|
||||||
|
openai_token_mgr = token.TokenManager(self.ap, self.ap.cfg_mgr.data['openai_config']['api_key'].values())
|
||||||
|
|
||||||
|
self.model_list.append(
|
||||||
|
entities.LLMModelInfo(
|
||||||
|
name="gpt-3.5-turbo",
|
||||||
|
provider="openai",
|
||||||
|
token_mgr=openai_token_mgr,
|
||||||
|
requester=openai_chat_completion,
|
||||||
|
function_call_supported=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||||
|
"""通过名称获取模型
|
||||||
|
"""
|
||||||
|
for model in self.model_list:
|
||||||
|
if model.name == name:
|
||||||
|
return model
|
||||||
|
raise ValueError(f"Model {name} not found")
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
|
class TokenManager():
|
||||||
|
|
||||||
|
provider: str
|
||||||
|
|
||||||
|
tokens: list[str]
|
||||||
|
|
||||||
|
using_token_index: typing.Optional[int] = 0
|
||||||
|
|
||||||
|
def __init__(self, provider: str, tokens: list[str]):
|
||||||
|
self.provider = provider
|
||||||
|
self.tokens = tokens
|
||||||
|
self.using_token_index = 0
|
||||||
|
|
||||||
|
def get_token(self) -> str:
|
||||||
|
return self.tokens[self.using_token_index]
|
||||||
|
|
||||||
|
def next_token(self):
|
||||||
|
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import asyncio
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
from ..sysprompt import entities as sysprompt_entities
|
||||||
|
from .. import entities as llm_entities
|
||||||
|
from ..requester import entities
|
||||||
|
from ...core import entities as core_entities
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(pydantic.BaseModel):
|
||||||
|
"""对话"""
|
||||||
|
|
||||||
|
prompt: sysprompt_entities.Prompt
|
||||||
|
|
||||||
|
messages: list[llm_entities.Message]
|
||||||
|
|
||||||
|
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||||
|
|
||||||
|
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||||
|
|
||||||
|
use_model: entities.LLMModelInfo
|
||||||
|
|
||||||
|
|
||||||
|
class Session(pydantic.BaseModel):
|
||||||
|
"""会话"""
|
||||||
|
launcher_type: core_entities.LauncherTypes
|
||||||
|
|
||||||
|
launcher_id: int
|
||||||
|
|
||||||
|
sender_id: typing.Optional[int] = 0
|
||||||
|
|
||||||
|
use_prompt_name: typing.Optional[str] = 'default'
|
||||||
|
|
||||||
|
using_conversation: typing.Optional[Conversation] = None
|
||||||
|
|
||||||
|
conversations: typing.Optional[list[Conversation]] = []
|
||||||
|
|
||||||
|
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||||
|
|
||||||
|
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||||
|
|
||||||
|
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from ...core import app, entities as core_entities
|
||||||
|
from . import entities
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
session_list: list[entities.Session]
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
self.ap = ap
|
||||||
|
self.session_list = []
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_session(self, query: core_entities.Query) -> entities.Session:
|
||||||
|
"""获取会话
|
||||||
|
"""
|
||||||
|
for session in self.session_list:
|
||||||
|
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
||||||
|
return session
|
||||||
|
|
||||||
|
session = entities.Session(
|
||||||
|
launcher_type=query.launcher_type,
|
||||||
|
launcher_id=query.launcher_id,
|
||||||
|
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000)
|
||||||
|
)
|
||||||
|
self.session_list.append(session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
async def get_conversation(self, session: entities.Session) -> entities.Conversation:
|
||||||
|
if not session.conversations:
|
||||||
|
session.conversations = []
|
||||||
|
|
||||||
|
if session.using_conversation is None:
|
||||||
|
conversation = entities.Conversation(
|
||||||
|
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||||
|
messages=[],
|
||||||
|
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
|
||||||
|
)
|
||||||
|
session.conversations.append(conversation)
|
||||||
|
session.using_conversation = conversation
|
||||||
|
|
||||||
|
return session.using_conversation
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
from ...openai import entities
|
||||||
|
|
||||||
|
|
||||||
|
class Prompt(pydantic.BaseModel):
|
||||||
|
"""供AI使用的Prompt"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
messages: list[entities.Message]
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import abc
|
||||||
|
|
||||||
|
from ...core import app
|
||||||
|
from . import entities
|
||||||
|
|
||||||
|
|
||||||
|
class PromptLoader(metaclass=abc.ABCMeta):
|
||||||
|
"""Prompt加载器抽象类
|
||||||
|
"""
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
prompts: list[entities.Prompt]
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
self.ap = ap
|
||||||
|
self.prompts = []
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def load(self):
|
||||||
|
"""加载Prompt
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_prompts(self) -> list[entities.Prompt]:
|
||||||
|
"""获取Prompt列表
|
||||||
|
"""
|
||||||
|
return self.prompts
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from .. import loader
|
||||||
|
from .. import entities
|
||||||
|
from ....openai import entities as llm_entities
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioPromptLoader(loader.PromptLoader):
|
||||||
|
"""加载scenario目录下的json"""
|
||||||
|
|
||||||
|
async def load(self):
|
||||||
|
"""加载Prompt
|
||||||
|
"""
|
||||||
|
for file in os.listdir("scenarios"):
|
||||||
|
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
|
||||||
|
file_str = f.read()
|
||||||
|
file_name = file.split(".")[0]
|
||||||
|
file_json = json.loads(file_str)
|
||||||
|
messages = []
|
||||||
|
for msg in file_json["prompt"]:
|
||||||
|
role = llm_entities.MessageRole.SYSTEM
|
||||||
|
if "role" in msg:
|
||||||
|
if msg["role"] == "user":
|
||||||
|
role = llm_entities.MessageRole.USER
|
||||||
|
elif msg["role"] == "system":
|
||||||
|
role = llm_entities.MessageRole.SYSTEM
|
||||||
|
elif msg["role"] == "function":
|
||||||
|
role = llm_entities.MessageRole.FUNCTION
|
||||||
|
messages.append(
|
||||||
|
llm_entities.Message(
|
||||||
|
role=role,
|
||||||
|
content=msg['content'],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
prompt = entities.Prompt(
|
||||||
|
name=file_name,
|
||||||
|
messages=messages
|
||||||
|
)
|
||||||
|
self.prompts.append(prompt)
|
||||||
|
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import os
|
||||||
|
|
||||||
|
from .. import loader
|
||||||
|
from .. import entities
|
||||||
|
from ....openai import entities as llm_entities
|
||||||
|
|
||||||
|
|
||||||
|
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||||
|
"""配置文件中的单条system prompt的prompt加载器
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def load(self):
|
||||||
|
"""加载Prompt
|
||||||
|
"""
|
||||||
|
|
||||||
|
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
|
||||||
|
prompt = entities.Prompt(
|
||||||
|
name=name,
|
||||||
|
messages=[
|
||||||
|
llm_entities.Message(
|
||||||
|
role=llm_entities.MessageRole.SYSTEM,
|
||||||
|
content=cnt
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.prompts.append(prompt)
|
||||||
|
|
||||||
|
for file in os.listdir("prompts"):
|
||||||
|
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
|
||||||
|
file_str = f.read()
|
||||||
|
file_name = file.split(".")[0]
|
||||||
|
prompt = entities.Prompt(
|
||||||
|
name=file_name,
|
||||||
|
messages=[
|
||||||
|
llm_entities.Message(
|
||||||
|
role=llm_entities.MessageRole.SYSTEM,
|
||||||
|
content=file_str
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.prompts.append(prompt)
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ...core import app
|
||||||
|
from . import loader
|
||||||
|
from .loaders import single, scenario
|
||||||
|
|
||||||
|
|
||||||
|
class PromptManager:
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
loader_inst: loader.PromptLoader
|
||||||
|
|
||||||
|
default_prompt: str = 'default'
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
self.ap = ap
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
|
||||||
|
loader_map = {
|
||||||
|
"normal": single.SingleSystemPromptLoader,
|
||||||
|
"full_scenario": scenario.ScenarioPromptLoader
|
||||||
|
}
|
||||||
|
|
||||||
|
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
|
||||||
|
|
||||||
|
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
||||||
|
|
||||||
|
await self.loader_inst.initialize()
|
||||||
|
await self.loader_inst.load()
|
||||||
|
|
||||||
|
def get_all_prompts(self) -> list[loader.entities.Prompt]:
|
||||||
|
"""获取所有Prompt
|
||||||
|
"""
|
||||||
|
return self.loader_inst.get_prompts()
|
||||||
|
|
||||||
|
async def get_prompt(self, name: str) -> loader.entities.Prompt:
|
||||||
|
"""获取Prompt
|
||||||
|
"""
|
||||||
|
for prompt in self.get_all_prompts():
|
||||||
|
if prompt.name == name:
|
||||||
|
return prompt
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
|
||||||
|
from ...core import app
|
||||||
|
from ...core import entities as core_entities
|
||||||
|
from .. import entities
|
||||||
|
|
||||||
|
|
||||||
|
class MessageHandler(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
|
ap: app.Application
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
self.ap = ap
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
query: core_entities.Query,
|
||||||
|
) -> entities.StageProcessResult:
|
||||||
|
raise NotImplementedError
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import mirai
|
||||||
|
|
||||||
|
from .. import handler
|
||||||
|
from ... import entities
|
||||||
|
from ....core import entities as core_entities
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageHandler(handler.MessageHandler):
|
||||||
|
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
query: core_entities.Query,
|
||||||
|
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||||
|
"""处理
|
||||||
|
"""
|
||||||
|
# 取session
|
||||||
|
# 取conversation
|
||||||
|
# 调API
|
||||||
|
# 生成器
|
||||||
|
session = await self.ap.sess_mgr.get_session(query)
|
||||||
|
|
||||||
|
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||||
|
|
||||||
|
async for result in conversation.use_model.requester.request(query, conversation):
|
||||||
|
query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))])
|
||||||
|
|
||||||
|
yield entities.StageProcessResult(
|
||||||
|
result_type=entities.ResultType.CONTINUE,
|
||||||
|
new_query=query
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import mirai
|
||||||
|
|
||||||
|
from .. import handler
|
||||||
|
from ... import entities
|
||||||
|
from ....core import entities as core_entities
|
||||||
|
|
||||||
|
|
||||||
|
class CommandHandler(handler.MessageHandler):
|
||||||
|
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
query: core_entities.Query,
|
||||||
|
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||||
|
"""处理
|
||||||
|
"""
|
||||||
|
query.resp_message_chain = mirai.MessageChain([
|
||||||
|
mirai.Plain('CommandHandler')
|
||||||
|
])
|
||||||
|
|
||||||
|
yield entities.StageProcessResult(
|
||||||
|
result_type=entities.ResultType.CONTINUE,
|
||||||
|
new_query=query
|
||||||
|
)
|
||||||
|
|
||||||
|
query.resp_message_chain = mirai.MessageChain([
|
||||||
|
mirai.Plain('The Second Message')
|
||||||
|
])
|
||||||
|
|
||||||
|
yield entities.StageProcessResult(
|
||||||
|
result_type=entities.ResultType.CONTINUE,
|
||||||
|
new_query=query
|
||||||
|
)
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ...core import app, entities as core_entities
|
||||||
|
from . import handler
|
||||||
|
from .handlers import chat, command
|
||||||
|
from .. import entities
|
||||||
|
from .. import stage, entities, stagemgr
|
||||||
|
from ...core import entities as core_entities
|
||||||
|
from ...config import manager as cfg_mgr
|
||||||
|
|
||||||
|
|
||||||
|
@stage.stage_class("MessageProcessor")
|
||||||
|
class Processor(stage.PipelineStage):
|
||||||
|
|
||||||
|
cmd_handler: handler.MessageHandler
|
||||||
|
|
||||||
|
chat_handler: handler.MessageHandler
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
self.cmd_handler = command.CommandHandler(self.ap)
|
||||||
|
self.chat_handler = chat.ChatMessageHandler(self.ap)
|
||||||
|
|
||||||
|
await self.cmd_handler.initialize()
|
||||||
|
await self.chat_handler.initialize()
|
||||||
|
|
||||||
|
async def process(
|
||||||
|
self,
|
||||||
|
query: core_entities.Query,
|
||||||
|
stage_inst_name: str,
|
||||||
|
) -> entities.StageProcessResult:
|
||||||
|
"""处理
|
||||||
|
"""
|
||||||
|
message_text = str(query.message_chain).strip()
|
||||||
|
|
||||||
|
if message_text.startswith('!') or message_text.startswith('!'):
|
||||||
|
return self.cmd_handler.handle(query)
|
||||||
|
else:
|
||||||
|
return self.chat_handler.handle(query)
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import mirai
|
||||||
|
|
||||||
|
from ...core import app
|
||||||
|
|
||||||
|
from .. import stage, entities, stagemgr
|
||||||
|
from ...core import entities as core_entities
|
||||||
|
from ...config import manager as cfg_mgr
|
||||||
|
|
||||||
|
|
||||||
|
@stage.stage_class("SendResponseBackStage")
|
||||||
|
class SendResponseBackStage(stage.PipelineStage):
|
||||||
|
"""发送响应消息
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
|
"""处理
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self.ap.im_mgr.send(
|
||||||
|
query.message_event,
|
||||||
|
query.resp_message_chain
|
||||||
|
)
|
||||||
|
|
||||||
|
return entities.StageProcessResult(
|
||||||
|
result_type=entities.ResultType.CONTINUE,
|
||||||
|
new_query=query
|
||||||
|
)
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import typing
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app, entities as core_entities
|
||||||
from . import entities
|
from . import entities
|
||||||
@@ -37,7 +38,10 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
|||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: core_entities.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> entities.StageProcessResult:
|
) -> typing.Union[
|
||||||
|
entities.StageProcessResult,
|
||||||
|
typing.AsyncGenerator[entities.StageProcessResult, None],
|
||||||
|
]:
|
||||||
"""处理
|
"""处理
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -7,7 +7,20 @@ from . import stage
|
|||||||
from .resprule import resprule
|
from .resprule import resprule
|
||||||
from .bansess import bansess
|
from .bansess import bansess
|
||||||
from .cntfilter import cntfilter
|
from .cntfilter import cntfilter
|
||||||
|
from .process import process
|
||||||
from .longtext import longtext
|
from .longtext import longtext
|
||||||
|
from .respback import respback
|
||||||
|
|
||||||
|
|
||||||
|
stage_order = [
|
||||||
|
"GroupRespondRuleCheckStage",
|
||||||
|
"BanSessionCheckStage",
|
||||||
|
"PreContentFilterStage",
|
||||||
|
"MessageProcessor",
|
||||||
|
"PostContentFilterStage",
|
||||||
|
"LongTextProcessStage",
|
||||||
|
"SendResponseBackStage",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class StageInstContainer():
|
class StageInstContainer():
|
||||||
@@ -45,3 +58,6 @@ class StageManager:
|
|||||||
|
|
||||||
for stage_containers in self.stage_containers:
|
for stage_containers in self.stage_containers:
|
||||||
await stage_containers.inst.initialize()
|
await stage_containers.inst.initialize()
|
||||||
|
|
||||||
|
# 按照 stage_order 排序
|
||||||
|
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))
|
||||||
|
|||||||
@@ -18,10 +18,6 @@ from ..plugin import host as plugin_host
|
|||||||
from ..plugin import models as plugin_models
|
from ..plugin import models as plugin_models
|
||||||
import tips as tips_custom
|
import tips as tips_custom
|
||||||
from ..qqbot import adapter as msadapter
|
from ..qqbot import adapter as msadapter
|
||||||
from .resprule import resprule
|
|
||||||
from .bansess import bansess
|
|
||||||
from .cntfilter import cntfilter
|
|
||||||
from .longtext import longtext
|
|
||||||
from .ratelim import ratelim
|
from .ratelim import ratelim
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app, entities as core_entities
|
||||||
@@ -41,30 +37,18 @@ class QQBotManager:
|
|||||||
# modern
|
# modern
|
||||||
ap: app.Application = None
|
ap: app.Application = None
|
||||||
|
|
||||||
bansess_mgr: bansess.SessionBanManager = None
|
|
||||||
cntfilter_mgr: cntfilter.ContentFilterManager = None
|
|
||||||
longtext_pcs: longtext.LongTextProcessor = None
|
|
||||||
resprule_chkr: resprule.GroupRespondRuleChecker = None
|
|
||||||
ratelimiter: ratelim.RateLimiter = None
|
ratelimiter: ratelim.RateLimiter = None
|
||||||
|
|
||||||
def __init__(self, first_time_init=True, ap: app.Application = None):
|
def __init__(self, first_time_init=True, ap: app.Application = None):
|
||||||
config = context.get_config_manager().data
|
config = context.get_config_manager().data
|
||||||
|
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.bansess_mgr = bansess.SessionBanManager(ap)
|
|
||||||
self.cntfilter_mgr = cntfilter.ContentFilterManager(ap)
|
|
||||||
self.longtext_pcs = longtext.LongTextProcessor(ap)
|
|
||||||
self.resprule_chkr = resprule.GroupRespondRuleChecker(ap)
|
|
||||||
self.ratelimiter = ratelim.RateLimiter(ap)
|
self.ratelimiter = ratelim.RateLimiter(ap)
|
||||||
|
|
||||||
self.timeout = config['process_message_timeout']
|
self.timeout = config['process_message_timeout']
|
||||||
self.retry = config['retry_times']
|
self.retry = config['retry_times']
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
await self.bansess_mgr.initialize()
|
|
||||||
await self.cntfilter_mgr.initialize()
|
|
||||||
await self.longtext_pcs.initialize()
|
|
||||||
await self.resprule_chkr.initialize()
|
|
||||||
await self.ratelimiter.initialize()
|
await self.ratelimiter.initialize()
|
||||||
|
|
||||||
config = context.get_config_manager().data
|
config = context.get_config_manager().data
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from ..plugin import host as plugin_host
|
|||||||
from ..plugin import models as plugin_models
|
from ..plugin import models as plugin_models
|
||||||
import tips as tips_custom
|
import tips as tips_custom
|
||||||
from ..core import app
|
from ..core import app
|
||||||
from .cntfilter import entities
|
# from .cntfilter import entities
|
||||||
|
|
||||||
processing = []
|
processing = []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user