chore: 整理文件

This commit is contained in:
RockChinQ
2024-01-28 18:45:18 +08:00
parent 2b0faea8ec
commit 698782c537
50 changed files with 29 additions and 1007 deletions

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import typing
from ..core import app, entities as core_entities
from ..openai import entities as llm_entities
from ..openai.session import entities as session_entities
from ..gai import entities as llm_entities
from ..gai.session import entities as session_entities
from . import entities, operator, errors
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update

View File

@@ -6,7 +6,7 @@ import pydantic
import mirai
from ..core import app, entities as core_entities
from ..openai.session import entities as session_entities
from ..gai.session import entities as session_entities
from . import errors, operator

View File

@@ -4,7 +4,7 @@ import typing
import abc
from ..core import app, entities as core_entities
from ..openai.session import entities as session_entities
from ..gai.session import entities as session_entities
from . import entities

View File

@@ -3,11 +3,11 @@ from __future__ import annotations
import logging
import asyncio
from ..qqbot import manager as qqbot_mgr
from ..openai.session import sessionmgr as llm_session_mgr
from ..openai.requester import modelmgr as llm_model_mgr
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
from ..openai.tools import toolmgr as llm_tool_mgr
from ..im import manager as qqbot_mgr
from ..gai.session import sessionmgr as llm_session_mgr
from ..gai.requester import modelmgr as llm_model_mgr
from ..gai.sysprompt import sysprompt as llm_prompt_mgr
from ..gai.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
from ..database import manager as database_mgr
from ..utils.center import v2 as center_mgr

View File

@@ -14,11 +14,11 @@ from . import controller
from ..pipeline import stagemgr
from ..audit import identifier
from ..database import manager as db_mgr
from ..openai.session import sessionmgr as llm_session_mgr
from ..openai.requester import modelmgr as llm_model_mgr
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
from ..openai.tools import toolmgr as llm_tool_mgr
from ..qqbot import manager as im_mgr
from ..gai.session import sessionmgr as llm_session_mgr
from ..gai.requester import modelmgr as llm_model_mgr
from ..gai.sysprompt import sysprompt as llm_prompt_mgr
from ..gai.tools import toolmgr as llm_tool_mgr
from ..im import manager as im_mgr
from ..command import cmdmgr
from ..plugin import host as plugin_host
from ..utils.center import v2 as center_v2

View File

@@ -1,3 +0,0 @@
"""
数据库操作封装
"""

View File

@@ -1,365 +0,0 @@
"""
数据库管理模块
"""
import hashlib
import json
import logging
import time
import sqlite3
from ..utils import context
class DatabaseManager:
"""封装数据库底层操作,并提供方法给上层使用"""
conn = None
cursor = None
def __init__(self, *args, **kwargs):
self.reconnect()
context.set_database_manager(self)
# 连接到数据库文件
def reconnect(self):
"""连接到数据库"""
self.conn = sqlite3.connect('database.db', check_same_thread=False)
self.cursor = self.conn.cursor()
def close(self):
self.conn.close()
def __execute__(self, *args, **kwargs) -> sqlite3.Cursor:
# logging.debug('SQL: {}'.format(sql))
logging.debug('SQL: {}'.format(args))
c = self.cursor.execute(*args, **kwargs)
self.conn.commit()
return c
# 初始化数据库的函数
def initialize_database(self):
"""创建数据表"""
self.__execute__("""
create table if not exists `sessions` (
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`name` varchar(255) not null,
`type` varchar(255) not null,
`number` bigint not null,
`create_timestamp` bigint not null,
`last_interact_timestamp` bigint not null,
`status` varchar(255) not null default 'on_going',
`default_prompt` text not null default '',
`prompt` text not null,
`token_counts` text not null default '[]'
)
""")
# 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段
self.__execute__("PRAGMA table_info('sessions')")
columns = self.cursor.fetchall()
has_default_prompt = False
has_token_counts = False
for field in columns:
if field[1] == 'default_prompt':
has_default_prompt = True
if field[1] == 'token_counts':
has_token_counts = True
if has_default_prompt and has_token_counts:
break
if not has_default_prompt:
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
if not has_token_counts:
self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'")
self.__execute__("""
create table if not exists `account_fee`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`key_md5` varchar(255) not null,
`timestamp` bigint not null,
`fee` DECIMAL(12,6) not null
)
""")
self.__execute__("""
create table if not exists `account_usage`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`json` text not null
)
""")
# print('Database initialized.')
# session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''):
"""持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session
# 如果有就更新prompt和last_interact_timestamp
# 如果没有,就插入一条新的记录
self.__execute__("""
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(subject_type, subject_number, create_timestamp))
count = self.cursor.fetchone()[0]
if count == 0:
sql = """
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`)
values (?, ?, ?, ?, ?, ?, ?, ?)
"""
self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt, default_prompt, token_counts))
else:
sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
where `type` = ? and `number` = ? and `create_timestamp` = ?
"""
self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type,
subject_number, create_timestamp))
# 显式关闭一个session
def explicit_close_session(self, session_name: str, create_timestamp: int):
self.__execute__("""
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))
def set_session_ongoing(self, session_name: str, create_timestamp: int):
self.__execute__("""
update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))
# 设置session为过期
def set_session_expired(self, session_name: str, create_timestamp: int):
self.__execute__("""
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))
# 从数据库加载还没过期的session数据
def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session
config = context.get_config_manager().data
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config['session_expire_time']))
results = self.cursor.fetchall()
sessions = {}
for result in results:
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
# 当且仅当最后一个该对象的会话是on_going状态时才会被加载
if status == 'on_going':
sessions[session_name] = {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
else:
if session_name in sessions:
del sessions[session_name]
return sessions
# 获取此session_name前一个session的数据
def last_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1
""".format(session_name, cursor_timestamp))
results = self.cursor.fetchall()
if len(results) == 0:
return None
result = results[0]
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
# 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1
""".format(session_name, cursor_timestamp))
results = self.cursor.fetchall()
if len(results) == 0:
return None
result = results[0]
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
# 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page))
results = self.cursor.fetchall()
sessions = []
for result in results:
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
sessions.append({
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
})
return sessions
def delete_history(self, session_name: str, index: int) -> bool:
# 删除倒序第index个session
# 查找其id再删除
self.__execute__("""
delete from `sessions` where `id` in (select `id` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit 1 offset {})
""".format(session_name, index))
return self.cursor.rowcount == 1
def delete_all_history(self, session_name: str) -> bool:
self.__execute__("""
delete from `sessions` where `name` = '{}'
""".format(session_name))
return self.cursor.rowcount > 0
def delete_all_session_history(self) -> bool:
self.__execute__("""
delete from `sessions`
""")
return self.cursor.rowcount > 0
# 将apikey的使用量存进数据库
def dump_api_key_usage(self, api_keys: dict, usage: dict):
logging.debug('dumping api key usage...')
logging.debug(api_keys)
logging.debug(usage)
for api_key in api_keys:
# 计算key的md5值
key_md5 = hashlib.md5(api_keys[api_key].encode('utf-8')).hexdigest()
# 获取使用量
usage_count = 0
if key_md5 in usage:
usage_count = usage[key_md5]
# 将使用量存进数据库
# 先检查是否已存在
self.__execute__("""
select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5))
result = self.cursor.fetchone()
if result[0] == 0:
# 不存在则插入
self.__execute__("""
insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {})
""".format(key_md5, usage_count, int(time.time())))
else:
# 存在则更新timestamp设置为当前
self.__execute__("""
update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}'
""".format(usage_count, int(time.time()), key_md5))
def load_api_key_usage(self):
self.__execute__("""
select `key_md5`, `usage` from `api_key_usage`
""")
results = self.cursor.fetchall()
usage = {}
for result in results:
key_md5 = result[0]
usage_count = result[1]
usage[key_md5] = usage_count
return usage
def dump_usage_json(self, usage: dict):
json_str = json.dumps(usage)
self.__execute__("""
select count(*) from `account_usage`""")
result = self.cursor.fetchone()
if result[0] == 0:
# 不存在则插入
self.__execute__("""
insert into `account_usage` (`json`) values ('{}')
""".format(json_str))
else:
# 存在则更新
self.__execute__("""
update `account_usage` set `json` = '{}' where `id` = 1
""".format(json_str))
def load_usage_json(self):
self.__execute__("""
select `json` from `account_usage` order by id desc limit 1
""")
result = self.cursor.fetchone()
if result is None:
return None
else:
return result[0]

View File

@@ -8,9 +8,9 @@ Completion - text-davinci-003 等模型
import tiktoken
import openai
from ..openai.api import model as api_model
from ..openai.api import completion as api_completion
from ..openai.api import chat_completion as api_chat_completion
from ..gai.api import model as api_model
from ..gai.api import completion as api_completion
from ..gai.api import chat_completion as api_chat_completion
COMPLETION_MODELS = {
"gpt-3.5-turbo-instruct",

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing
import pydantic
from ...openai import entities
from ...gai import entities
class Prompt(pydantic.BaseModel):

View File

@@ -5,7 +5,7 @@ import os
from .. import loader
from .. import entities
from ....openai import entities as llm_entities
from ....gai import entities as llm_entities
class ScenarioPromptLoader(loader.PromptLoader):

View File

@@ -3,7 +3,7 @@ import os
from .. import loader
from .. import entities
from ....openai import entities as llm_entities
from ....gai import entities as llm_entities
class SingleSystemPromptLoader(loader.PromptLoader):

View File

@@ -10,11 +10,11 @@ from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
import mirai
import func_timeout
from ..openai import session as openai_session
from ..gai import session as openai_session
from ..utils import context
import tips as tips_custom
from ..qqbot import adapter as msadapter
from ..im import adapter as msadapter
from .ratelim import ratelim
from ..core import app, entities as core_entities
@@ -44,13 +44,13 @@ class QQBotManager:
logging.debug("Use adapter:" + config['msg_source_adapter'])
if config['msg_source_adapter'] == 'yirimirai':
from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter
from pkg.im.sources.yirimirai import YiriMiraiAdapter
mirai_http_api_config = config['mirai_http_api_config']
self.bot_account_id = config['mirai_http_api_config']['qq']
self.adapter = YiriMiraiAdapter(mirai_http_api_config)
elif config['msg_source_adapter'] == 'nakuru':
from pkg.qqbot.sources.nakuru import NakuruProjectAdapter
from pkg.im.sources.nakuru import NakuruProjectAdapter
self.adapter = NakuruProjectAdapter(config['nakuru_config'])
self.bot_account_id = self.adapter.bot_account_id

View File

@@ -9,7 +9,7 @@ import nakuru
import nakuru.entities.components as nkc
from .. import adapter as adapter_model
from ...qqbot import blob
from ...im import blob
from ...utils import context

View File

@@ -7,7 +7,7 @@ import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....openai import entities as llm_entities
from ....gai import entities as llm_entities
class ChatMessageHandler(handler.MessageHandler):

View File

@@ -15,7 +15,7 @@ from ..utils import network as network
from ..utils import context as context
from ..plugin import switch as switch
from ..plugin import settings as settings
from ..qqbot import adapter as msadapter
from ..im import adapter as msadapter
from ..plugin import metadata as metadata
from mirai import Mirai

View File

@@ -1,128 +0,0 @@
from __future__ import annotations
import threading
from ..database import manager as db_mgr
from ..qqbot import manager as qqbot_mgr
from ..config import manager as config_mgr
from ..plugin import host as plugin_host
from .center import v2 as center_v2
context = {
'inst': {
'database.manager.DatabaseManager': None,
'openai.manager.OpenAIInteract': None,
'qqbot.manager.QQBotManager': None,
'config.manager.ConfigManager': None,
},
'pool_ctl': None,
'logger_handler': None,
'config': None,
'plugin_host': None,
}
context_lock = threading.Lock()
### context耦合度非常高需要大改 ###
def set_config(inst):
context_lock.acquire()
context['config'] = inst
context_lock.release()
def get_config():
context_lock.acquire()
t = context['config']
context_lock.release()
return t
def set_database_manager(inst: db_mgr.DatabaseManager):
context_lock.acquire()
context['inst']['database.manager.DatabaseManager'] = inst
context_lock.release()
def get_database_manager() -> db_mgr.DatabaseManager:
context_lock.acquire()
t = context['inst']['database.manager.DatabaseManager']
context_lock.release()
return t
def set_openai_manager(inst: openai_mgr.OpenAIInteract):
context_lock.acquire()
context['inst']['openai.manager.OpenAIInteract'] = inst
context_lock.release()
def get_openai_manager() -> openai_mgr.OpenAIInteract:
context_lock.acquire()
t = context['inst']['openai.manager.OpenAIInteract']
context_lock.release()
return t
def set_qqbot_manager(inst: qqbot_mgr.QQBotManager):
context_lock.acquire()
context['inst']['qqbot.manager.QQBotManager'] = inst
context_lock.release()
def get_qqbot_manager() -> qqbot_mgr.QQBotManager:
context_lock.acquire()
t = context['inst']['qqbot.manager.QQBotManager']
context_lock.release()
return t
def set_config_manager(inst: config_mgr.ConfigManager):
context_lock.acquire()
context['inst']['config.manager.ConfigManager'] = inst
context_lock.release()
def get_config_manager() -> config_mgr.ConfigManager:
context_lock.acquire()
t = context['inst']['config.manager.ConfigManager']
context_lock.release()
return t
def set_plugin_host(inst: plugin_host.PluginHost):
context_lock.acquire()
context['plugin_host'] = inst
context_lock.release()
def get_plugin_host() -> plugin_host.PluginHost:
context_lock.acquire()
t = context['plugin_host']
context_lock.release()
return t
def set_thread_ctl(inst: threadctl.ThreadCtl):
context_lock.acquire()
context['pool_ctl'] = inst
context_lock.release()
def get_thread_ctl() -> threadctl.ThreadCtl:
context_lock.acquire()
t: threadctl.ThreadCtl = context['pool_ctl']
context_lock.release()
return t
def set_center_v2_api(inst: center_v2.V2CenterAPI):
context_lock.acquire()
context['center_v2_api'] = inst
context_lock.release()
def get_center_v2_api() -> center_v2.V2CenterAPI:
context_lock.acquire()
t: center_v2.V2CenterAPI = context['center_v2_api']
context_lock.release()
return t