feat: 消息平台热重载

This commit is contained in:
Junyan Qin
2024-11-16 12:40:57 +08:00
parent 3239c9ec3f
commit bb219889e5
14 changed files with 168 additions and 36 deletions

13
main.py
View File

@@ -3,13 +3,14 @@
# QChatGPT/main.py # QChatGPT/main.py
asciiart = r""" asciiart = r"""
___ ___ _ _ ___ ___ _____ _ ___ _
/ _ \ / __| |_ __ _| |_ / __| _ \_ _| | | __ _ _ _ __ _| _ ) ___| |_
| (_) | (__| ' \/ _` | _| (_ | _/ | | | |__/ _` | ' \/ _` | _ \/ _ \ _|
\__\_\\___|_||_\__,_|\__|\___|_| |_| |____\__,_|_||_\__, |___/\___/\__|
|___/
⭐️开源地址: https://github.com/RockChinQ/QChatGPT ⭐️开源地址: https://github.com/RockChinQ/LangBot
📖文档地址: https://q.rkcn.top 📖文档地址: https://docs.langbot.app
""" """

View File

@@ -39,3 +39,25 @@ class SystemRouterGroup(group.RouterGroup):
return self.http_status(404, 404, "Task not found") return self.http_status(404, 404, "Task not found")
return self.success(data=task.to_dict()) return self.success(data=task.to_dict())
@self.route('/reload', methods=['POST'])
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get("scope")
await self.ap.reload(
scope=scope
)
return self.success()
@self.route('/_debug/exec', methods=['POST'])
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, "Forbidden")
py_code = await quart.request.data
ap = self.ap
return self.success(data=exec(py_code, {"ap": ap}))

View File

@@ -6,7 +6,7 @@ import os
import quart import quart
import quart_cors import quart_cors
from ....core import app from ....core import app, entities as core_entities
from .groups import logs, system, settings, plugins, stats from .groups import logs, system, settings, plugins, stats
from . import group from . import group
@@ -32,15 +32,26 @@ class HTTPController:
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
async def exception_handler(*args, **kwargs):
try:
await self.quart_app.run_task(
*args, **kwargs
)
except Exception as e:
self.ap.logger.error(f"启动 HTTP 服务失败: {e}")
self.ap.task_mgr.create_task( self.ap.task_mgr.create_task(
self.quart_app.run_task( exception_handler(
host=self.ap.system_cfg.data["http-api"]["host"], host=self.ap.system_cfg.data["http-api"]["host"],
port=self.ap.system_cfg.data["http-api"]["port"], port=self.ap.system_cfg.data["http-api"]["port"],
shutdown_trigger=shutdown_trigger_placeholder, shutdown_trigger=shutdown_trigger_placeholder,
), ),
name="http-api-quart", name="http-api-quart",
scopes=[core_entities.LifecycleControlScope.APPLICATION],
) )
# await asyncio.sleep(5)
async def register_routes(self) -> None: async def register_routes(self) -> None:
@self.quart_app.route("/healthz") @self.quart_app.route("/healthz")

View File

@@ -9,7 +9,7 @@ import asyncio
import aiohttp import aiohttp
import requests import requests
from ...core import app from ...core import app, entities as core_entities
class APIGroup(metaclass=abc.ABCMeta): class APIGroup(metaclass=abc.ABCMeta):
@@ -65,14 +65,12 @@ class APIGroup(metaclass=abc.ABCMeta):
**kwargs, **kwargs,
) -> asyncio.Task: ) -> asyncio.Task:
"""执行请求""" """执行请求"""
# task = asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
# self.ap.asyncio_tasks.append(task)
return self.ap.task_mgr.create_task( return self.ap.task_mgr.create_task(
self._do(method, path, data, params, headers, **kwargs), self._do(method, path, data, params, headers, **kwargs),
kind="telemetry-operation", kind="telemetry-operation",
name=f"{method} {path}", name=f"{method} {path}",
scopes=[core_entities.LifecycleControlScope.APPLICATION],
).task ).task
def gen_rid(self): def gen_rid(self):

View File

@@ -4,6 +4,7 @@ import logging
import asyncio import asyncio
import threading import threading
import traceback import traceback
import enum
from ..platform import manager as im_mgr from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr from ..provider.session import sessionmgr as llm_session_mgr
@@ -21,8 +22,9 @@ from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller from ..api.http.controller import main as http_controller
from ..utils import logcache from ..utils import logcache, ip
from . import taskmgr from . import taskmgr
from . import entities as core_entities
class Application: class Application:
@@ -114,11 +116,12 @@ class Application:
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager") self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
self.task_mgr.create_task(self.ctrl.run(), name="query-controller") self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller") self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(never_ending(), name="never-ending-task") self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION])
await self.print_web_access_info()
await self.task_mgr.wait_all() await self.task_mgr.wait_all()
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
@@ -126,9 +129,45 @@ class Application:
self.logger.error(f"应用运行致命异常: {e}") self.logger.error(f"应用运行致命异常: {e}")
self.logger.debug(f"Traceback: {traceback.format_exc()}") self.logger.debug(f"Traceback: {traceback.format_exc()}")
async def scoped_shutdown(self, scopes: list[str]): async def print_web_access_info(self):
pass """打印访问 webui 的提示"""
import socket
async def shutdown(self): host_ip = socket.gethostbyname(socket.gethostname())
for task in self.task_mgr.tasks:
task.cancel() public_ip = await ip.get_myip()
port = self.system_cfg.data['http-api']['port']
tips = f"""
=======================================
✨ 您可通过以下方式访问管理面板:
🏠 本地地址http://{host_ip}:{port}/
🌐 公网地址http://{public_ip}:{port}/
📌 如果您在容器中运行此程序,请确保容器的 {port} 端口已对外暴露
🔗 若要使用公网地址访问,请阅读以下须知
1. 公网地址仅供参考,请以您的主机公网 IP 为准;
2. 要使用公网地址访问,请确保您的主机具有公网 IP并且系统防火墙已放行 {port} 端口;
=======================================
""".strip()
for line in tips.split("\n"):
self.logger.info(line)
async def reload(
self,
scope: core_entities.LifecycleControlScope,
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info("执行热重载 scope="+scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
case _:
pass

View File

@@ -6,7 +6,7 @@ required_deps = {
"anthropic": "anthropic", "anthropic": "anthropic",
"colorlog": "colorlog", "colorlog": "colorlog",
"aiocqhttp": "aiocqhttp", "aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy", "botpy": "qq-botpy-rc",
"PIL": "pillow", "PIL": "pillow",
"nakuru": "nakuru-project-idk", "nakuru": "nakuru-project-idk",
"tiktoken": "tiktoken", "tiktoken": "tiktoken",

View File

@@ -17,6 +17,14 @@ from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities from ..platform.types import entities as platform_entities
class LifecycleControlScope(enum.Enum):
APPLICATION = "application"
PLATFORM = "platform"
PLUGIN = "plugin"
class LauncherTypes(enum.Enum): class LauncherTypes(enum.Enum):
"""一个请求的发起者类型""" """一个请求的发起者类型"""

View File

@@ -6,6 +6,7 @@ import datetime
import traceback import traceback
from . import app from . import app
from . import entities as core_entities
class TaskContext: class TaskContext:
@@ -71,7 +72,7 @@ class TaskWrapper:
task_type: str = "system" # 任务类型: system 或 user task_type: str = "system" # 任务类型: system 或 user
"""任务类型""" """任务类型"""
kind: str = "system_task" kind: str = "system_task" # 由发起者确定任务种类,通常同质化的任务种类相同
"""任务种类""" """任务种类"""
name: str = "" name: str = ""
@@ -92,6 +93,9 @@ class TaskWrapper:
ap: app.Application ap: app.Application
"""应用实例""" """应用实例"""
scopes: list[core_entities.LifecycleControlScope]
"""任务所属生命周期控制范围"""
def __init__( def __init__(
self, self,
ap: app.Application, ap: app.Application,
@@ -101,6 +105,7 @@ class TaskWrapper:
name: str = "", name: str = "",
label: str = "", label: str = "",
context: TaskContext = None, context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
): ):
self.id = TaskWrapper._id_index self.id = TaskWrapper._id_index
TaskWrapper._id_index += 1 TaskWrapper._id_index += 1
@@ -112,6 +117,7 @@ class TaskWrapper:
self.name = name self.name = name
self.label = label if label != "" else name self.label = label if label != "" else name
self.task.set_name(name) self.task.set_name(name)
self.scopes = scopes
def assume_exception(self): def assume_exception(self):
try: try:
@@ -145,6 +151,7 @@ class TaskWrapper:
"kind": self.kind, "kind": self.kind,
"name": self.name, "name": self.name,
"label": self.label, "label": self.label,
"scopes": [scope.value for scope in self.scopes],
"task_context": self.task_context.to_dict(), "task_context": self.task_context.to_dict(),
"runtime": { "runtime": {
"done": self.task.done(), "done": self.task.done(),
@@ -180,8 +187,9 @@ class AsyncTaskManager:
name: str = "", name: str = "",
label: str = "", label: str = "",
context: TaskContext = None, context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper: ) -> TaskWrapper:
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context) wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
self.tasks.append(wrapper) self.tasks.append(wrapper)
return wrapper return wrapper
@@ -192,8 +200,9 @@ class AsyncTaskManager:
name: str = "", name: str = "",
label: str = "", label: str = "",
context: TaskContext = None, context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper: ) -> TaskWrapper:
return self.create_task(coro, "user", kind, name, label, context) return self.create_task(coro, "user", kind, name, label, context, scopes)
async def wait_all(self): async def wait_all(self):
await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True) await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True)
@@ -217,3 +226,10 @@ class AsyncTaskManager:
if t.id == id: if t.id == id:
return t return t
return None return None
def cancel_by_scope(self, scope: core_entities.LifecycleControlScope):
for wrapper in self.tasks:
if not wrapper.task.done() and scope in wrapper.scopes:
wrapper.task.cancel()

View File

@@ -4,7 +4,6 @@ import asyncio
import typing import typing
import traceback import traceback
from ..core import app, entities from ..core import app, entities
from . import entities as pipeline_entities from . import entities as pipeline_entities
from ..plugin import events from ..plugin import events
@@ -59,13 +58,11 @@ class Controller:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了 # 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all() self.ap.query_pool.condition.notify_all()
# task = asyncio.create_task(_process_query(selected_query))
# self.ap.asyncio_tasks.append(task)
self.ap.task_mgr.create_task( self.ap.task_mgr.create_task(
_process_query(selected_query), _process_query(selected_query),
kind="query", kind="query",
name=f"query-{selected_query.query_id}", name=f"query-{selected_query.query_id}",
scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM],
) )
except Exception as e: except Exception as e:

View File

@@ -174,22 +174,23 @@ class PlatformManager:
try: try:
tasks = [] tasks = []
for adapter in self.adapters: for adapter in self.adapters:
async def exception_wrapper(adapter): async def exception_wrapper(adapter: msadapter.MessageSourceAdapter):
try: try:
await adapter.run_async() await adapter.run_async()
except Exception as e: except Exception as e:
if isinstance(e, asyncio.CancelledError):
return
self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.error('平台适配器运行出错: ' + str(e))
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
tasks.append(exception_wrapper(adapter)) tasks.append(exception_wrapper(adapter))
for task in tasks: for task in tasks:
# async_task = asyncio.create_task(task)
# self.ap.asyncio_tasks.append(async_task)
self.ap.task_mgr.create_task( self.ap.task_mgr.create_task(
task, task,
kind="platform-adapter", kind="platform-adapter",
name=f"platform-adapter-{adapter.name}", name=f"platform-adapter-{adapter.name}",
scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM],
) )
except Exception as e: except Exception as e:
@@ -199,3 +200,4 @@ class PlatformManager:
async def shutdown(self): async def shutdown(self):
for adapter in self.adapters: for adapter in self.adapters:
await adapter.kill() await adapter.kill()
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)

View File

@@ -588,8 +588,12 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
self.member_openid_mapping, self.group_openid_mapping self.member_openid_mapping, self.group_openid_mapping
) )
self.cfg['ret_coro'] = True
self.ap.logger.info("运行 QQ 官方适配器") self.ap.logger.info("运行 QQ 官方适配器")
await self.bot.start(**self.cfg) await (await self.bot.start(**self.cfg))
async def kill(self) -> bool: async def kill(self) -> bool:
return False if not self.bot.is_closed():
await self.bot.close()
return True

9
pkg/utils/ip.py Normal file
View File

@@ -0,0 +1,9 @@
import aiohttp
async def get_myip() -> str:
try:
async with aiohttp.ClientSession() as session:
async with session.get("https://ip.useragentinfo.com/myip") as response:
return await response.text()
except Exception as e:
return '0.0.0.0'

View File

@@ -3,7 +3,7 @@ openai>1.0.0
anthropic anthropic
colorlog~=6.6.0 colorlog~=6.6.0
aiocqhttp aiocqhttp
qq-botpy qq-botpy-rc
nakuru-project-idk nakuru-project-idk
Pillow Pillow
tiktoken tiktoken

View File

@@ -67,6 +67,12 @@
</v-list-item-title> </v-list-item-title>
</v-list-item> </v-list-item>
<v-list-item @click="reload('platform')">
<v-list-item-title>
重载消息平台
</v-list-item-title>
</v-list-item>
</v-list> </v-list>
</v-menu> </v-menu>
</v-list-item> </v-list-item>
@@ -137,6 +143,25 @@ function openDocs() {
window.open('https://docs.langbot.app', '_blank') window.open('https://docs.langbot.app', '_blank')
} }
function reload(scope) {
proxy.$axios.post('/system/reload',
{ scope: scope },
{ headers: { 'Content-Type': 'application/json' } }
).then(response => {
if (response.data.code === 0) {
success('消息平台已重载')
// 关闭菜单
} else {
error('消息平台重载失败:' + response.data.message)
}
}).catch(error => {
console.error(error)
error('消息平台重载失败:' + error)
})
}
const aboutDialogShow = ref(false) const aboutDialogShow = ref(false)
function showAboutDialog() { function showAboutDialog() {