perf: 为 context.py 中的方法添加类型提示

This commit is contained in:
RockChinQ
2023-11-26 17:33:13 +08:00
parent 9fccf84987
commit 7708eaa82c

View File

@@ -1,6 +1,13 @@
from __future__ import annotations
import threading
from . import threadctl
from ..database import manager as db_mgr
from ..openai import manager as openai_mgr
from ..qqbot import manager as qqbot_mgr
from ..plugin import host as plugin_host
context = {
'inst': {
@@ -29,59 +36,59 @@ def get_config():
return t
def set_database_manager(inst):
def set_database_manager(inst: db_mgr.DatabaseManager):
context_lock.acquire()
context['inst']['database.manager.DatabaseManager'] = inst
context_lock.release()
def get_database_manager():
def get_database_manager() -> db_mgr.DatabaseManager:
context_lock.acquire()
t = context['inst']['database.manager.DatabaseManager']
context_lock.release()
return t
def set_openai_manager(inst):
def set_openai_manager(inst: openai_mgr.OpenAIInteract):
context_lock.acquire()
context['inst']['openai.manager.OpenAIInteract'] = inst
context_lock.release()
def get_openai_manager():
def get_openai_manager() -> openai_mgr.OpenAIInteract:
context_lock.acquire()
t = context['inst']['openai.manager.OpenAIInteract']
context_lock.release()
return t
def set_qqbot_manager(inst):
def set_qqbot_manager(inst: qqbot_mgr.QQBotManager):
context_lock.acquire()
context['inst']['qqbot.manager.QQBotManager'] = inst
context_lock.release()
def get_qqbot_manager():
def get_qqbot_manager() -> qqbot_mgr.QQBotManager:
context_lock.acquire()
t = context['inst']['qqbot.manager.QQBotManager']
context_lock.release()
return t
def set_plugin_host(inst):
def set_plugin_host(inst: plugin_host.PluginHost):
context_lock.acquire()
context['plugin_host'] = inst
context_lock.release()
def get_plugin_host():
def get_plugin_host() -> plugin_host.PluginHost:
context_lock.acquire()
t = context['plugin_host']
context_lock.release()
return t
def set_thread_ctl(inst):
def set_thread_ctl(inst: threadctl.ThreadCtl):
context_lock.acquire()
context['pool_ctl'] = inst
context_lock.release()