diff --git a/.gitignore b/.gitignore index 2869b7cc..db62bdca 100644 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,5 @@ botpy.log* test.py /web_ui .venv/ -uv.lock \ No newline at end of file +uv.lock +/test \ No newline at end of file diff --git a/libs/wechatpad_api/__init__.py b/libs/wechatpad_api/__init__.py index 23c23fb2..9ac533f7 100644 --- a/libs/wechatpad_api/__init__.py +++ b/libs/wechatpad_api/__init__.py @@ -1 +1 @@ -from .client import WeChatPadClient \ No newline at end of file +from .client import WeChatPadClient as WeChatPadClient diff --git a/libs/wechatpad_api/api/chatroom.py b/libs/wechatpad_api/api/chatroom.py index a7af207c..2d9281a2 100644 --- a/libs/wechatpad_api/api/chatroom.py +++ b/libs/wechatpad_api/api/chatroom.py @@ -1,4 +1,4 @@ -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json class ChatRoomApi: @@ -7,8 +7,6 @@ class ChatRoomApi: self.token = token def get_chatroom_member_detail(self, chatroom_name): - params = { - "ChatRoomName": chatroom_name - } + params = {'ChatRoomName': chatroom_name} url = self.base_url + '/group/GetChatroomMemberDetail' return post_json(url, token=self.token, data=params) diff --git a/libs/wechatpad_api/api/downloadpai.py b/libs/wechatpad_api/api/downloadpai.py index a82a5674..2d45fac6 100644 --- a/libs/wechatpad_api/api/downloadpai.py +++ b/libs/wechatpad_api/api/downloadpai.py @@ -1,32 +1,23 @@ -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json import httpx import base64 + class DownloadApi: def __init__(self, base_url, token): self.base_url = base_url self.token = token def send_download(self, aeskey, file_type, file_url): - json_data = { - "AesKey": aeskey, - "FileType": file_type, - "FileURL": file_url - } - url = self.base_url + "/message/SendCdnDownload" + json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url} + url = self.base_url 
+ '/message/SendCdnDownload' return post_json(url, token=self.token, data=json_data) - def get_msg_voice(self,buf_id, length, new_msgid): - json_data = { - "Bufid": buf_id, - "Length": length, - "NewMsgId": new_msgid, - "ToUserName": "" - } - url = self.base_url + "/message/GetMsgVoice" + def get_msg_voice(self, buf_id, length, new_msgid): + json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''} + url = self.base_url + '/message/GetMsgVoice' return post_json(url, token=self.token, data=json_data) - async def download_url_to_base64(self, download_url): async with httpx.AsyncClient() as client: response = await client.get(download_url) @@ -36,4 +27,4 @@ class DownloadApi: base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式 return base64_str else: - raise Exception('获取文件失败') \ No newline at end of file + raise Exception('获取文件失败') diff --git a/libs/wechatpad_api/api/friend.py b/libs/wechatpad_api/api/friend.py index 00701a5d..a7a448aa 100644 --- a/libs/wechatpad_api/api/friend.py +++ b/libs/wechatpad_api/api/friend.py @@ -1,11 +1,6 @@ -from libs.wechatpad_api.util.http_util import post_json,async_request -from typing import List, Dict, Any, Optional - - class FriendApi: """联系人API类,处理所有与联系人相关的操作""" def __init__(self, base_url: str, token: str): self.base_url = base_url self.token = token - diff --git a/libs/wechatpad_api/api/login.py b/libs/wechatpad_api/api/login.py index 142a3c85..4aa4ae8d 100644 --- a/libs/wechatpad_api/api/login.py +++ b/libs/wechatpad_api/api/login.py @@ -1,37 +1,34 @@ -from libs.wechatpad_api.util.http_util import async_request,post_json,get_json +from libs.wechatpad_api.util.http_util import post_json, get_json class LoginApi: def __init__(self, base_url: str, token: str = None, admin_key: str = None): - ''' + """ Args: base_url: 原始路径 token: token admin_key: 管理员key - ''' + """ self.base_url = base_url self.token = token # self.admin_key = admin_key - def get_token(self, admin_key, day: int=365): + 
def get_token(self, admin_key, day: int = 365): # 获取普通token - url = f"{self.base_url}/admin/GenAuthKey1" - json_data = { - "Count": 1, - "Days": day - } + url = f'{self.base_url}/admin/GenAuthKey1' + json_data = {'Count': 1, 'Days': day} return post_json(base_url=url, token=admin_key, data=json_data) - def get_login_qr(self, Proxy: str = ""): - ''' + def get_login_qr(self, Proxy: str = ''): + """ Args: Proxy:异地使用时代理 Returns:json数据 - ''' + """ """ { @@ -49,54 +46,37 @@ class LoginApi: } """ - #获取登录二维码 - url = f"{self.base_url}/login/GetLoginQrCodeNew" + # 获取登录二维码 + url = f'{self.base_url}/login/GetLoginQrCodeNew' check = False - if Proxy != "": + if Proxy != '': check = True - json_data = { - "Check": check, - "Proxy": Proxy - } + json_data = {'Check': check, 'Proxy': Proxy} return post_json(base_url=url, token=self.token, data=json_data) - def get_login_status(self): # 获取登录状态 url = f'{self.base_url}/login/GetLoginStatus' return get_json(base_url=url, token=self.token) - - def logout(self): # 退出登录 url = f'{self.base_url}/login/LogOut' return post_json(base_url=url, token=self.token) - - - - def wake_up_login(self, Proxy: str = ""): + def wake_up_login(self, Proxy: str = ''): # 唤醒登录 url = f'{self.base_url}/login/WakeUpLogin' check = False - if Proxy != "": + if Proxy != '': check = True - json_data = { - "Check": check, - "Proxy": "" - } + json_data = {'Check': check, 'Proxy': ''} return post_json(base_url=url, token=self.token, data=json_data) - - - def login(self,admin_key): + def login(self, admin_key): login_status = self.get_login_status() - if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信": - print("token已经失效,重新获取") + if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信': + print('token已经失效,重新获取') token_data = self.get_token(admin_key) - self.token = token_data["Data"][0] - - - + self.token = token_data['Data'][0] diff --git a/libs/wechatpad_api/api/message.py b/libs/wechatpad_api/api/message.py index 2089ce96..cca76313 100644 
--- a/libs/wechatpad_api/api/message.py +++ b/libs/wechatpad_api/api/message.py @@ -1,5 +1,4 @@ - -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json class MessageApi: @@ -7,8 +6,8 @@ class MessageApi: self.base_url = base_url self.token = token - def post_text(self, to_wxid, content, ats: list= []): - ''' + def post_text(self, to_wxid, content, ats: list = []): + """ Args: app_id: 微信id @@ -18,106 +17,64 @@ class MessageApi: Returns: - ''' - url = self.base_url + "/message/SendTextMessage" + """ + url = self.base_url + '/message/SendTextMessage' """发送文字消息""" json_data = { - "MsgItem": [ - { - "AtWxIDList": ats, - "ImageContent": "", - "MsgType": 0, - "TextContent": content, - "ToUserName": to_wxid - } - ] - } - return post_json(base_url=url, token=self.token, data=json_data) + 'MsgItem': [ + {'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid} + ] + } + return post_json(base_url=url, token=self.token, data=json_data) - - - - def post_image(self, to_wxid, img_url, ats: list= []): + def post_image(self, to_wxid, img_url, ats: list = []): """发送图片消息""" # 这里好像可以尝试发送多个暂时未测试 json_data = { - "MsgItem": [ - { - "AtWxIDList": ats, - "ImageContent": img_url, - "MsgType": 0, - "TextContent": '', - "ToUserName": to_wxid - } + 'MsgItem': [ + {'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid} ] } - url = self.base_url + "/message/SendImageMessage" + url = self.base_url + '/message/SendImageMessage' return post_json(base_url=url, token=self.token, data=json_data) def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration): """发送语音消息""" json_data = { - "ToUserName": to_wxid, - "VoiceData": voice_data, - "VoiceFormat": voice_forma, - "VoiceSecond": voice_duration + 'ToUserName': to_wxid, + 'VoiceData': voice_data, + 'VoiceFormat': voice_forma, + 'VoiceSecond': voice_duration, } - url = self.base_url 
+ "/message/SendVoice" + url = self.base_url + '/message/SendVoice' return post_json(base_url=url, token=self.token, data=json_data) - - - - def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag): """发送名片消息""" param = { - "CardAlias": alias, - "CardFlag": flag, - "CardNickName": nick_name, - "CardWxId": name_card_wxid, - "ToUserName": to_wxid + 'CardAlias': alias, + 'CardFlag': flag, + 'CardNickName': nick_name, + 'CardWxId': name_card_wxid, + 'ToUserName': to_wxid, } - url = f"{self.base_url}/message/ShareCardMessage" + url = f'{self.base_url}/message/ShareCardMessage' return post_json(base_url=url, token=self.token, data=param) - def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0): + def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0): """发送emoji消息""" - json_data = { - "EmojiList": [ - { - "EmojiMd5": emoji_md5, - "EmojiSize": emoji_size, - "ToUserName": to_wxid - } - ] - } - url = f"{self.base_url}/message/SendEmojiMessage" + json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]} + url = f'{self.base_url}/message/SendEmojiMessage' return post_json(base_url=url, token=self.token, data=json_data) - def post_app_msg(self, to_wxid,xml_data, contenttype:int=0): + def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0): """发送appmsg消息""" - json_data = { - "AppList": [ - { - "ContentType": contenttype, - "ContentXML": xml_data, - "ToUserName": to_wxid - } - ] - } - url = f"{self.base_url}/message/SendAppMessage" + json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]} + url = f'{self.base_url}/message/SendAppMessage' return post_json(base_url=url, token=self.token, data=json_data) - - def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time): """撤回消息""" - param = { - "ClientMsgId": msg_id, - "CreateTime": create_time, - "NewMsgId": new_msg_id, - "ToUserName": to_wxid - } - url = f"{self.base_url}/message/RevokeMsg" - return 
post_json(base_url=url, token=self.token, data=param) \ No newline at end of file + param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid} + url = f'{self.base_url}/message/RevokeMsg' + return post_json(base_url=url, token=self.token, data=param) diff --git a/libs/wechatpad_api/util/http_util.py b/libs/wechatpad_api/util/http_util.py index 754003e9..447c29df 100644 --- a/libs/wechatpad_api/util/http_util.py +++ b/libs/wechatpad_api/util/http_util.py @@ -1,10 +1,9 @@ import requests +import aiohttp + def post_json(base_url, token, data=None): - headers = { - 'Content-Type': 'application/json' - } - + headers = {'Content-Type': 'application/json'} url = base_url + f'?key={token}' @@ -18,14 +17,12 @@ def post_json(base_url, token, data=None): else: raise RuntimeError(response.text) except Exception as e: - print(f"http请求失败, url={url}, exception={e}") + print(f'http请求失败, url={url}, exception={e}') raise RuntimeError(str(e)) -def get_json(base_url, token): - headers = { - 'Content-Type': 'application/json' - } +def get_json(base_url, token): + headers = {'Content-Type': 'application/json'} url = base_url + f'?key={token}' @@ -39,21 +36,18 @@ def get_json(base_url, token): else: raise RuntimeError(response.text) except Exception as e: - print(f"http请求失败, url={url}, exception={e}") + print(f'http请求失败, url={url}, exception={e}') raise RuntimeError(str(e)) -import aiohttp -import asyncio - async def async_request( - base_url: str, - token_key: str, - method: str = 'POST', - params: dict = None, - # headers: dict = None, - data: dict = None, - json: dict = None + base_url: str, + token_key: str, + method: str = 'POST', + params: dict = None, + # headers: dict = None, + data: dict = None, + json: dict = None, ): """ 通用异步请求函数 @@ -67,18 +61,11 @@ async def async_request( :param json: JSON数据 :return: 响应文本 """ - headers = { - 'Content-Type': 'application/json' - } - url = f"{base_url}?key={token_key}" + headers = {'Content-Type': 
'application/json'} + url = f'{base_url}?key={token_key}' async with aiohttp.ClientSession() as session: async with session.request( - method=method, - url=url, - params=params, - headers=headers, - data=data, - json=json + method=method, url=url, params=params, headers=headers, data=data, json=json ) as response: response.raise_for_status() # 如果状态码不是200,抛出异常 result = await response.json() @@ -89,4 +76,3 @@ async def async_request( # return await result # else: # raise RuntimeError("请求失败",response.text) - diff --git a/pkg/api/http/controller/group.py b/pkg/api/http/controller/group.py index 73780208..16fa1df1 100644 --- a/pkg/api/http/controller/group.py +++ b/pkg/api/http/controller/group.py @@ -14,8 +14,8 @@ preregistered_groups: list[type[RouterGroup]] = [] """Pre-registered list of RouterGroup""" -def group_class(name: str, path: str) -> None: - """Register a RouterGroup""" +def group_class(name: str, path: str) -> typing.Callable[[typing.Type[RouterGroup]], typing.Type[RouterGroup]]: + """注册一个 RouterGroup""" def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]: cls.name = name @@ -86,10 +86,11 @@ class RouterGroup(abc.ABC): try: return await f(*args, **kwargs) - except Exception: # auto 500 + + except Exception as e: # 自动 500 traceback.print_exc() # return self.http_status(500, -2, str(e)) - return self.http_status(500, -2, 'internal server error') + return self.http_status(500, -2, str(e)) new_f = handler_error new_f.__name__ = (self.name + rule).replace('/', '__') @@ -120,6 +121,6 @@ class RouterGroup(abc.ABC): } ) - def http_status(self, status: int, code: int, msg: str) -> quart.Response: - """Return a response with a specified status code""" - return self.fail(code, msg), status + def http_status(self, status: int, code: int, msg: str) -> typing.Tuple[quart.Response, int]: + """返回一个指定状态码的响应""" + return (self.fail(code, msg), status) \ No newline at end of file diff --git a/pkg/api/http/controller/groups/files.py 
b/pkg/api/http/controller/groups/files.py index 0a8b2210..b3c1a3f1 100644 --- a/pkg/api/http/controller/groups/files.py +++ b/pkg/api/http/controller/groups/files.py @@ -2,6 +2,10 @@ from __future__ import annotations import quart import mimetypes +import uuid +import asyncio + +import quart.datastructures from .. import group @@ -20,3 +24,23 @@ class FilesRouterGroup(group.RouterGroup): mime_type = 'image/jpeg' return quart.Response(image_bytes, mimetype=mime_type) + + @self.route('/documents', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) + async def _() -> quart.Response: + request = quart.request + # get file bytes from 'file' + file = (await request.files)['file'] + assert isinstance(file, quart.datastructures.FileStorage) + + file_bytes = await asyncio.to_thread(file.stream.read) + extension = file.filename.split('.')[-1] + file_name = file.filename.split('.')[0] + + file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension + # save file to storage + await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes) + return self.success( + data={ + 'file_id': file_key, + } + ) diff --git a/pkg/api/http/controller/groups/knowledge/__init__.py b/pkg/api/http/controller/groups/knowledge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py new file mode 100644 index 00000000..a5bed5df --- /dev/null +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -0,0 +1,90 @@ +import quart +from ... 
import group + + +@group.group_class('knowledge_base', '/api/v1/knowledge/bases') +class KnowledgeBaseRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['POST', 'GET']) + async def handle_knowledge_bases() -> quart.Response: + if quart.request.method == 'GET': + knowledge_bases = await self.ap.knowledge_service.get_knowledge_bases() + return self.success(data={'bases': knowledge_bases}) + + elif quart.request.method == 'POST': + json_data = await quart.request.json + knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data) + return self.success(data={'uuid': knowledge_base_uuid}) + + return self.http_status(405, -1, 'Method not allowed') + + @self.route( + '/', + methods=['GET', 'DELETE', 'PUT'], + ) + async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> quart.Response: + if quart.request.method == 'GET': + knowledge_base = await self.ap.knowledge_service.get_knowledge_base(knowledge_base_uuid) + + if knowledge_base is None: + return self.http_status(404, -1, 'knowledge base not found') + + return self.success( + data={ + 'base': knowledge_base, + } + ) + + elif quart.request.method == 'PUT': + json_data = await quart.request.json + await self.ap.knowledge_service.update_knowledge_base(knowledge_base_uuid, json_data) + return self.success({}) + + elif quart.request.method == 'DELETE': + await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid) + return self.success({}) + + @self.route( + '//files', + methods=['GET', 'POST'], + ) + async def get_knowledge_base_files(knowledge_base_uuid: str) -> str: + if quart.request.method == 'GET': + files = await self.ap.knowledge_service.get_files_by_knowledge_base(knowledge_base_uuid) + return self.success( + data={ + 'files': files, + } + ) + + elif quart.request.method == 'POST': + json_data = await quart.request.json + file_id = json_data.get('file_id') + if not file_id: + return self.http_status(400, -1, 'File ID 
is required') + + # 调用服务层方法将文件与知识库关联 + task_id = await self.ap.knowledge_service.store_file(knowledge_base_uuid, file_id) + return self.success( + { + 'task_id': task_id, + } + ) + + @self.route( + '//files/', + methods=['DELETE'], + ) + async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str: + await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id) + return self.success({}) + + @self.route( + '//retrieve', + methods=['POST'], + ) + async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str: + json_data = await quart.request.json + query = json_data.get('query') + results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query) + return self.success(data={'results': results}) diff --git a/pkg/api/http/controller/groups/provider/models.py b/pkg/api/http/controller/groups/provider/models.py index bb77986c..0de0c922 100644 --- a/pkg/api/http/controller/groups/provider/models.py +++ b/pkg/api/http/controller/groups/provider/models.py @@ -9,18 +9,18 @@ class LLMModelsRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success(data={'models': await self.ap.model_service.get_llm_models()}) + return self.success(data={'models': await self.ap.llm_model_service.get_llm_models()}) elif quart.request.method == 'POST': json_data = await quart.request.json - model_uuid = await self.ap.model_service.create_llm_model(json_data) + model_uuid = await self.ap.llm_model_service.create_llm_model(json_data) return self.success(data={'uuid': model_uuid}) @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(model_uuid: str) -> str: if quart.request.method == 'GET': - model = await self.ap.model_service.get_llm_model(model_uuid) + model = await self.ap.llm_model_service.get_llm_model(model_uuid) if model is None: return self.http_status(404, -1, 'model not found') @@ -29,11 +29,11 @@ class 
LLMModelsRouterGroup(group.RouterGroup): elif quart.request.method == 'PUT': json_data = await quart.request.json - await self.ap.model_service.update_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.update_llm_model(model_uuid, json_data) return self.success() elif quart.request.method == 'DELETE': - await self.ap.model_service.delete_llm_model(model_uuid) + await self.ap.llm_model_service.delete_llm_model(model_uuid) return self.success() @@ -41,6 +41,49 @@ class LLMModelsRouterGroup(group.RouterGroup): async def _(model_uuid: str) -> str: json_data = await quart.request.json - await self.ap.model_service.test_llm_model(model_uuid, json_data) + await self.ap.llm_model_service.test_llm_model(model_uuid, json_data) + + return self.success() + + +@group.group_class('models/embedding', '/api/v1/provider/models/embedding') +class EmbeddingModelsRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['GET', 'POST']) + async def _() -> str: + if quart.request.method == 'GET': + return self.success(data={'models': await self.ap.embedding_models_service.get_embedding_models()}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + + model_uuid = await self.ap.embedding_models_service.create_embedding_model(json_data) + + return self.success(data={'uuid': model_uuid}) + + @self.route('/', methods=['GET', 'PUT', 'DELETE']) + async def _(model_uuid: str) -> str: + if quart.request.method == 'GET': + model = await self.ap.embedding_models_service.get_embedding_model(model_uuid) + + if model is None: + return self.http_status(404, -1, 'model not found') + + return self.success(data={'model': model}) + elif quart.request.method == 'PUT': + json_data = await quart.request.json + + await self.ap.embedding_models_service.update_embedding_model(model_uuid, json_data) + + return self.success() + elif quart.request.method == 'DELETE': + await 
self.ap.embedding_models_service.delete_embedding_model(model_uuid) + + return self.success() + + @self.route('//test', methods=['POST']) + async def _(model_uuid: str) -> str: + json_data = await quart.request.json + + await self.ap.embedding_models_service.test_embedding_model(model_uuid, json_data) return self.success() diff --git a/pkg/api/http/controller/groups/provider/requesters.py b/pkg/api/http/controller/groups/provider/requesters.py index 0f999288..af9e1540 100644 --- a/pkg/api/http/controller/groups/provider/requesters.py +++ b/pkg/api/http/controller/groups/provider/requesters.py @@ -8,7 +8,8 @@ class RequestersRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> quart.Response: - return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()}) + model_type = quart.request.args.get('type', '') + return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info(model_type)}) @self.route('/', methods=['GET']) async def _(requester_name: str) -> quart.Response: diff --git a/pkg/api/http/controller/main.py b/pkg/api/http/controller/main.py index 0191ead5..e45b461d 100644 --- a/pkg/api/http/controller/main.py +++ b/pkg/api/http/controller/main.py @@ -14,11 +14,13 @@ from . 
import group from .groups import provider as groups_provider from .groups import platform as groups_platform from .groups import pipelines as groups_pipelines +from .groups import knowledge as groups_knowledge importutil.import_modules_in_pkg(groups) importutil.import_modules_in_pkg(groups_provider) importutil.import_modules_in_pkg(groups_platform) importutil.import_modules_in_pkg(groups_pipelines) +importutil.import_modules_in_pkg(groups_knowledge) class HTTPController: diff --git a/pkg/api/http/service/knowledge.py b/pkg/api/http/service/knowledge.py new file mode 100644 index 00000000..27506ec9 --- /dev/null +++ b/pkg/api/http/service/knowledge.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import uuid +import sqlalchemy + +from ....core import app +from ....entity.persistence import rag as persistence_rag + + +class KnowledgeService: + """知识库服务""" + + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_knowledge_bases(self) -> list[dict]: + """获取所有知识库""" + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase)) + knowledge_bases = result.all() + return [ + self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base) + for knowledge_base in knowledge_bases + ] + + async def get_knowledge_base(self, kb_uuid: str) -> dict | None: + """获取知识库""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid) + ) + knowledge_base = result.first() + if knowledge_base is None: + return None + return self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base) + + async def create_knowledge_base(self, kb_data: dict) -> str: + """创建知识库""" + kb_data['uuid'] = str(uuid.uuid4()) + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data)) + + kb = await 
self.get_knowledge_base(kb_data['uuid']) + + await self.ap.rag_mgr.load_knowledge_base(kb) + + return kb_data['uuid'] + + async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None: + """更新知识库""" + if 'uuid' in kb_data: + del kb_data['uuid'] + + if 'embedding_model_uuid' in kb_data: + del kb_data['embedding_model_uuid'] + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.KnowledgeBase) + .values(kb_data) + .where(persistence_rag.KnowledgeBase.uuid == kb_uuid) + ) + await self.ap.rag_mgr.remove_knowledge_base_from_runtime(kb_uuid) + + kb = await self.get_knowledge_base(kb_uuid) + + await self.ap.rag_mgr.load_knowledge_base(kb) + + async def store_file(self, kb_uuid: str, file_id: str) -> int: + """存储文件""" + # await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(kb_id=kb_uuid, file_id=file_id)) + # await self.ap.rag_mgr.store_file(file_id) + runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + if runtime_kb is None: + raise Exception('Knowledge base not found') + return await runtime_kb.store_file(file_id) + + async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]: + """检索知识库""" + runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + if runtime_kb is None: + raise Exception('Knowledge base not found') + return [result.model_dump() for result in await runtime_kb.retrieve(query)] + + async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]: + """获取知识库文件""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid) + ) + files = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_rag.File, file) for file in files] + + async def delete_file(self, kb_uuid: str, file_id: str) -> None: + """删除文件""" + runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + if runtime_kb is None: + raise 
Exception('Knowledge base not found') + await runtime_kb.delete_file(file_id) + + async def delete_knowledge_base(self, kb_uuid: str) -> None: + """删除知识库""" + await self.ap.rag_mgr.delete_knowledge_base(kb_uuid) + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid) + ) + + # delete files + files = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid) + ) + for file in files: + # delete chunks + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file.uuid) + ) + # delete file + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file.uuid) + ) diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 74fb4e02..d8457da3 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -10,7 +10,7 @@ from ....provider.modelmgr import requester as model_requester from ....provider import entities as llm_entities -class ModelsService: +class LLMModelsService: ap: app.Application def __init__(self, ap: app.Application) -> None: @@ -103,3 +103,89 @@ class ModelsService: funcs=[], extra_args={}, ) + + +class EmbeddingModelsService: + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_embedding_models(self) -> list[dict]: + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + + models = result.all() + return [self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) for model in models] + + async def create_embedding_model(self, model_data: dict) -> str: + model_data['uuid'] = str(uuid.uuid4()) + + await self.ap.persistence_mgr.execute_async( + 
sqlalchemy.insert(persistence_model.EmbeddingModel).values(**model_data) + ) + + embedding_model = await self.get_embedding_model(model_data['uuid']) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + return model_data['uuid'] + + async def get_embedding_model(self, model_uuid: str) -> dict | None: + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + model = result.first() + + if model is None: + return None + + return self.ap.persistence_mgr.serialize_model(persistence_model.EmbeddingModel, model) + + async def update_embedding_model(self, model_uuid: str, model_data: dict) -> None: + if 'uuid' in model_data: + del model_data['uuid'] + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_model.EmbeddingModel) + .where(persistence_model.EmbeddingModel.uuid == model_uuid) + .values(**model_data) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + embedding_model = await self.get_embedding_model(model_uuid) + + await self.ap.model_mgr.load_embedding_model(embedding_model) + + async def delete_embedding_model(self, model_uuid: str) -> None: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_model.EmbeddingModel).where( + persistence_model.EmbeddingModel.uuid == model_uuid + ) + ) + + await self.ap.model_mgr.remove_embedding_model(model_uuid) + + async def test_embedding_model(self, model_uuid: str, model_data: dict) -> None: + runtime_embedding_model: model_requester.RuntimeEmbeddingModel | None = None + + if model_uuid != '_': + for model in self.ap.model_mgr.embedding_models: + if model.model_entity.uuid == model_uuid: + runtime_embedding_model = model + break + + if runtime_embedding_model is None: + raise Exception('model not found') + + else: + runtime_embedding_model = await self.ap.model_mgr.init_runtime_embedding_model(model_data) + + await 
runtime_embedding_model.requester.invoke_embedding( + model=runtime_embedding_model, + input_text=['Hello, world!'], + extra_args={}, + ) diff --git a/pkg/core/app.py b/pkg/core/app.py index 23ce2759..21816cfc 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -22,11 +22,14 @@ from ..api.http.service import user as user_service from ..api.http.service import model as model_service from ..api.http.service import pipeline as pipeline_service from ..api.http.service import bot as bot_service +from ..api.http.service import knowledge as knowledge_service from ..discover import engine as discover_engine from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities +from ..rag.knowledge import kbmgr as rag_mgr +from ..vector import mgr as vectordb_mgr class Application: @@ -47,6 +50,8 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None + rag_mgr: rag_mgr.RAGManager = None + # TODO move to pipeline tool_mgr: llm_tool_mgr.ToolManager = None @@ -93,6 +98,8 @@ class Application: persistence_mgr: persistencemgr.PersistenceManager = None + vector_db_mgr: vectordb_mgr.VectorDBManager = None + http_ctrl: http_controller.HTTPController = None log_cache: logcache.LogCache = None @@ -103,12 +110,16 @@ class Application: user_service: user_service.UserService = None - model_service: model_service.ModelsService = None + llm_model_service: model_service.LLMModelsService = None + + embedding_models_service: model_service.EmbeddingModelsService = None pipeline_service: pipeline_service.PipelineService = None bot_service: bot_service.BotService = None + knowledge_service: knowledge_service.KnowledgeService = None + def __init__(self): pass @@ -143,6 +154,7 @@ class Application: name='http-api-controller', scopes=[core_entities.LifecycleControlScope.APPLICATION], ) + self.task_mgr.create_task( never_ending(), name='never-ending-task', diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 
4caf18ed..8dc51e5b 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -19,7 +19,7 @@ class LifecycleControlScope(enum.Enum): APPLICATION = 'application' PLATFORM = 'platform' PLUGIN = 'plugin' - PROVIDER = 'provider' + PROVIDER = 'provider' class LauncherTypes(enum.Enum): diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index d4b443cf..0f28f0c8 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -9,6 +9,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.tools import toolmgr as llm_tool_mgr +from ...rag.knowledge import kbmgr as rag_mgr from ...platform import botmgr as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -16,9 +17,11 @@ from ...api.http.service import user as user_service from ...api.http.service import model as model_service from ...api.http.service import pipeline as pipeline_service from ...api.http.service import bot as bot_service +from ...api.http.service import knowledge as knowledge_service from ...discover import engine as discover_engine from ...storage import mgr as storagemgr from ...utils import logcache +from ...vector import mgr as vectordb_mgr from .. 
import taskmgr @@ -88,6 +91,15 @@ class BuildAppStage(stage.BootingStage): await pipeline_mgr.initialize() ap.pipeline_mgr = pipeline_mgr + rag_mgr_inst = rag_mgr.RAGManager(ap) + await rag_mgr_inst.initialize() + ap.rag_mgr = rag_mgr_inst + + # 初始化向量数据库管理器 + vectordb_mgr_inst = vectordb_mgr.VectorDBManager(ap) + await vectordb_mgr_inst.initialize() + ap.vector_db_mgr = vectordb_mgr_inst + http_ctrl = http_controller.HTTPController(ap) await http_ctrl.initialize() ap.http_ctrl = http_ctrl @@ -95,8 +107,11 @@ class BuildAppStage(stage.BootingStage): user_service_inst = user_service.UserService(ap) ap.user_service = user_service_inst - model_service_inst = model_service.ModelsService(ap) - ap.model_service = model_service_inst + llm_model_service_inst = model_service.LLMModelsService(ap) + ap.llm_model_service = llm_model_service_inst + + embedding_models_service_inst = model_service.EmbeddingModelsService(ap) + ap.embedding_models_service = embedding_models_service_inst pipeline_service_inst = pipeline_service.PipelineService(ap) ap.pipeline_service = pipeline_service_inst @@ -104,5 +119,8 @@ class BuildAppStage(stage.BootingStage): bot_service_inst = bot_service.BotService(ap) ap.bot_service = bot_service_inst + knowledge_service_inst = knowledge_service.KnowledgeService(ap) + ap.knowledge_service = knowledge_service_inst + ctrl = controller.Controller(ap) ap.ctrl = ctrl diff --git a/pkg/entity/persistence/model.py b/pkg/entity/persistence/model.py index 6cf93ec7..e9a104c4 100644 --- a/pkg/entity/persistence/model.py +++ b/pkg/entity/persistence/model.py @@ -23,3 +23,24 @@ class LLMModel(Base): server_default=sqlalchemy.func.now(), onupdate=sqlalchemy.func.now(), ) + + +class EmbeddingModel(Base): + """Embedding 模型""" + + __tablename__ = 'embedding_models' + + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + description = sqlalchemy.Column(sqlalchemy.String(255), 
nullable=False) + requester = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) + requester_config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + sqlalchemy.DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/pkg/entity/persistence/pipeline.py b/pkg/entity/persistence/pipeline.py index 8f28b242..3a21dbf2 100644 --- a/pkg/entity/persistence/pipeline.py +++ b/pkg/entity/persistence/pipeline.py @@ -20,7 +20,6 @@ class LegacyPipeline(Base): ) for_version = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False) - stages = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) @@ -43,3 +42,4 @@ class PipelineRunRecord(Base): started_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False) finished_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False) result = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) + knowledge_base_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py new file mode 100644 index 00000000..0ff93d28 --- /dev/null +++ b/pkg/entity/persistence/rag.py @@ -0,0 +1,50 @@ +import sqlalchemy +from .base import Base + +# Base = declarative_base() +# DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db') +# print("Using database URL:", DATABASE_URL) + + +# engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False}) + +# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# def 
create_db_and_tables(): +# """Creates all database tables defined in the Base.""" +# Base.metadata.create_all(bind=engine) +# print('Database tables created or already exist.') + + +class KnowledgeBase(Base): + __tablename__ = 'knowledge_bases' + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(sqlalchemy.String, index=True) + description = sqlalchemy.Column(sqlalchemy.Text) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now()) + embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='') + top_k = sqlalchemy.Column(sqlalchemy.Integer, default=5) + + +class File(Base): + __tablename__ = 'knowledge_base_files' + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + kb_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + file_name = sqlalchemy.Column(sqlalchemy.String) + extension = sqlalchemy.Column(sqlalchemy.String) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now()) + status = sqlalchemy.Column(sqlalchemy.String, default='pending') # pending, processing, completed, failed + + +class Chunk(Base): + __tablename__ = 'knowledge_base_chunks' + uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) + file_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) + text = sqlalchemy.Column(sqlalchemy.Text) + + +# class Vector(Base): +# __tablename__ = 'knowledge_base_vectors' +# uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) +# chunk_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) +# embedding = sqlalchemy.Column(sqlalchemy.LargeBinary) diff --git a/pkg/entity/persistence/vector.py b/pkg/entity/persistence/vector.py new file mode 100644 index 00000000..465125f5 --- /dev/null +++ b/pkg/entity/persistence/vector.py @@ -0,0 +1,13 @@ +from sqlalchemy import Column, Integer, ForeignKey, LargeBinary +from sqlalchemy.orm import 
declarative_base, relationship + +Base = declarative_base() + + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship('Chunk', back_populates='vector') diff --git a/pkg/entity/rag/__init__.py b/pkg/entity/rag/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/entity/rag/retriever.py b/pkg/entity/rag/retriever.py new file mode 100644 index 00000000..becaf8db --- /dev/null +++ b/pkg/entity/rag/retriever.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import pydantic + +from typing import Any + + +class RetrieveResultEntry(pydantic.BaseModel): + id: str + + metadata: dict[str, Any] + + distance: float diff --git a/pkg/persistence/mgr.py b/pkg/persistence/mgr.py index 9d2bab7b..3aa21ad2 100644 --- a/pkg/persistence/mgr.py +++ b/pkg/persistence/mgr.py @@ -79,7 +79,7 @@ class PersistenceManager: 'stages': pipeline_service.default_stage_order, 'is_default': True, 'name': 'ChatPipeline', - 'description': 'Default pipeline provided, your new bots will be automatically bound to this pipeline | 默认提供的流水线,您配置的机器人将自动绑定到此流水线', + 'description': 'Default pipeline, new bots will be bound to this pipeline | 默认提供的流水线,您配置的机器人将自动绑定到此流水线', 'config': pipeline_config, } diff --git a/pkg/persistence/migrations/dbm004_rag_kb_uuid.py b/pkg/persistence/migrations/dbm004_rag_kb_uuid.py new file mode 100644 index 00000000..b45cfa78 --- /dev/null +++ b/pkg/persistence/migrations/dbm004_rag_kb_uuid.py @@ -0,0 +1,38 @@ +from .. 
import migration + +import sqlalchemy + +from ...entity.persistence import pipeline as persistence_pipeline + + +@migration.migration_class(4) +class DBMigrateRAGKBUUID(migration.DBMigration): + """RAG知识库UUID""" + + async def upgrade(self): + """升级""" + # read all pipelines + pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) + + for pipeline in pipelines: + serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) + + config = serialized_pipeline['config'] + + if 'knowledge-base' not in config['ai']['local-agent']: + config['ai']['local-agent']['knowledge-base'] = '' + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_pipeline.LegacyPipeline) + .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid']) + .values( + { + 'config': config, + 'for_version': self.ap.ver_mgr.get_current_version(), + } + ) + ) + + async def downgrade(self): + """降级""" + pass diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index b61e34ad..77df09dc 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -144,23 +144,27 @@ class RuntimePipeline: result = await result if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 - self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}') + self.ap.logger.debug( + f'Stage {stage_container.inst_name} processed query {query.query_id} res {result.result_type}' + ) await self._check_output(query, result) if result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}') + self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}') break elif result.result_type == pipeline_entities.ResultType.CONTINUE: query = result.new_query elif isinstance(result, typing.AsyncGenerator): # 生成器 - 
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen') + self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query.query_id} gen') async for sub_result in result: - self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}') + self.ap.logger.debug( + f'Stage {stage_container.inst_name} processed query {query.query_id} res {sub_result.result_type}' + ) await self._check_output(query, sub_result) if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}') + self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}') break elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: query = sub_result.new_query @@ -192,7 +196,7 @@ class RuntimePipeline: if event_ctx.is_prevented_default(): return - self.ap.logger.debug(f'Processing query {query}') + self.ap.logger.debug(f'Processing query {query.query_id}') await self._execute_from_stage(0, query) except Exception as e: @@ -200,7 +204,7 @@ class RuntimePipeline: self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}') self.ap.logger.error(f'Traceback: {traceback.format_exc()}') finally: - self.ap.logger.debug(f'Query {query} processed') + self.ap.logger.debug(f'Query {query.query_id} processed') class PipelineManager: diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index bfa0924d..1aada6b3 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -80,14 +80,15 @@ class PreProcessor(stage.PipelineStage): if me.type == 'image_url': msg.content.remove(me) - content_list = [] + content_list: list[llm_entities.ContentElement] = [] plain_text = '' qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message') + # tidy the content_list + # combine all text content into one, and put it 
in the first position for me in query.message_chain: if isinstance(me, platform_message.Plain): - content_list.append(llm_entities.ContentElement.from_text(me.text)) plain_text += me.text elif isinstance(me, platform_message.Image): if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( @@ -106,6 +107,8 @@ class PreProcessor(stage.PipelineStage): if msg.base64 is not None: content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64)) + content_list.insert(0, llm_entities.ContentElement.from_text(plain_text)) + query.variables['user_message_text'] = plain_text query.user_message = llm_entities.Message(role='user', content=content_list) diff --git a/pkg/platform/logger.py b/pkg/platform/logger.py index 340baa07..a2ea2e25 100644 --- a/pkg/platform/logger.py +++ b/pkg/platform/logger.py @@ -119,7 +119,7 @@ class EventLogger: async def _truncate_logs(self): if len(self.logs) > MAX_LOG_COUNT: for i in range(DELETE_COUNT_PER_TIME): - for image_key in self.logs[i].images: + for image_key in self.logs[i].images: # type: ignore await self.ap.storage_mgr.storage_provider.delete(image_key) self.logs = self.logs[DELETE_COUNT_PER_TIME:] diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index 929636a5..c279e714 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -654,10 +654,10 @@ class DiscordMessageConverter(adapter.MessageConverter): # 确保路径没有空字节 clean_path = ele.path.replace('\x00', '') clean_path = os.path.abspath(clean_path) - + if not os.path.exists(clean_path): continue # 跳过不存在的文件 - + try: with open(clean_path, 'rb') as f: image_bytes = f.read() @@ -677,12 +677,13 @@ class DiscordMessageConverter(adapter.MessageConverter): filename = f'{uuid.uuid4()}.webp' # 默认保持PNG except Exception as e: - print(f"Error reading image file {clean_path}: {e}") + print(f'Error reading image file {clean_path}: {e}') continue # 跳过读取失败的文件 if image_bytes: # 
使用BytesIO创建文件对象,避免路径问题 import io + image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename)) elif isinstance(ele, platform_message.Plain): text_string += ele.text @@ -1003,25 +1004,25 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): msg_to_send, image_files = await self.message_converter.yiri2target(message) - + try: # 获取频道对象 channel = self.bot.get_channel(int(target_id)) if channel is None: # 如果本地缓存中没有,尝试从API获取 channel = await self.bot.fetch_channel(int(target_id)) - + args = { 'content': msg_to_send, } - + if len(image_files) > 0: args['files'] = image_files - + await channel.send(**args) - + except Exception as e: - await self.logger.error(f"Discord send_message failed: {e}") + await self.logger.error(f'Discord send_message failed: {e}') raise e async def reply_message( diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index d1116362..f8faf522 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -378,15 +378,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter): if 'im.message.receive_v1' == type: try: event = await self.event_converter.target2yiri(p2v1, self.api_client) - except Exception as e: - await self.logger.error(f"Error in lark callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) return {'code': 200, 'message': 'ok'} - except Exception as e: - await self.logger.error(f"Error in lark callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') return {'code': 500, 'message': 'error'} async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1): diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py 
index 389a2db1..16ad54db 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): content=content_list, ) nakuru_forward_node_list.append(nakuru_forward_node) - except Exception as e: + except Exception: import traceback + traceback.print_exc() nakuru_msg_list.append(nakuru_forward_node_list) @@ -276,7 +277,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): # 注册监听器 self.bot.receiver(source_cls.__name__)(listener_wrapper) except Exception as e: - self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}") + self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}') raise e def unregister_listener( diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 030db56d..3fc1e393 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -125,8 +125,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in officialaccount callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in officialaccount callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index c61afea4..63ab531f 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -154,10 +154,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员') self.bot = QQOfficialClient( - app_id=config['appid'], - secret=config['secret'], - token=config['token'], - logger=self.logger + app_id=config['appid'], 
secret=config['secret'], token=config['token'], logger=self.logger ) async def reply_message( @@ -224,8 +221,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = 'justbot' try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in qqofficial callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in qqofficial callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message) diff --git a/pkg/platform/sources/slack.py b/pkg/platform/sources/slack.py index 6dfcff59..1bd5aa2d 100644 --- a/pkg/platform/sources/slack.py +++ b/pkg/platform/sources/slack.py @@ -104,7 +104,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter): if missing_keys: raise ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员') - self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger) + self.bot = SlackClient( + bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger + ) async def reply_message( self, @@ -139,8 +141,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = 'SlackBot' try: return await callback(await self.event_converter.target2yiri(event, self.bot), self) - except Exception as e: - await self.logger.error(f"Error in slack callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in slack callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('im')(on_message) diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index 266d994e..c2fcc22e 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -160,8 +160,8 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): try: lb_event = await 
self.event_converter.target2yiri(update, self.bot, self.bot_account_id) await self.listeners[type(lb_event)](lb_event, self) - except Exception as e: - await self.logger.error(f"Error in telegram callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') self.application = ApplicationBuilder().token(self.config['token']).build() self.bot = self.application.bot diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index ac1be16b..a24287cb 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -1,5 +1,4 @@ import requests -import websockets import websocket import json import time @@ -10,32 +9,25 @@ from libs.wechatpad_api.client import WeChatPadClient import typing import asyncio import traceback -import time import re import base64 -import uuid -import json -import os import copy -import datetime import threading import quart -import aiohttp from .. 
import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ...utils import image from ..logger import EventLogger import xml.etree.ElementTree as ET -from typing import Optional, List, Tuple +from typing import Optional, Tuple from functools import partial import logging + class WeChatPadMessageConverter(adapter.MessageConverter): def __init__(self, config: dict, logger: logging.Logger): @@ -44,19 +36,14 @@ class WeChatPadMessageConverter(adapter.MessageConverter): self.logger = logger @staticmethod - async def yiri2target( - message_chain: platform_message.MessageChain - ) -> list[dict]: + async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]: content_list = [] - current_file_path = os.path.abspath(__file__) - - for component in message_chain: if isinstance(component, platform_message.At): - content_list.append({"type": "at", "target": component.target}) + content_list.append({'type': 'at', 'target': component.target}) elif isinstance(component, platform_message.Plain): - content_list.append({"type": "text", "content": component.text}) + content_list.append({'type': 'text', 'content': component.text}) elif isinstance(component, platform_message.Image): if component.url: async with httpx.AsyncClient() as client: @@ -68,15 +55,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter): else: raise Exception('获取文件失败') # pass - content_list.append({"type": "image", "image": base64_str}) + content_list.append({'type': 'image', 'image': base64_str}) elif component.base64: - content_list.append({"type": "image", "image": component.base64}) + content_list.append({'type': 'image', 'image': component.base64}) elif isinstance(component, platform_message.WeChatEmoji): content_list.append( - {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': 
component.emoji_size}) + {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size} + ) elif isinstance(component, platform_message.Voice): - content_list.append({"type": "voice", "data": component.url, "duration": component.length, "forma": 0}) + content_list.append({'type': 'voice', 'data': component.url, 'duration': component.length, 'forma': 0}) elif isinstance(component, platform_message.WeChatAppMsg): content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg}) elif isinstance(component, platform_message.Forward): @@ -86,7 +74,6 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return content_list - async def target2yiri( self, message: dict, @@ -97,11 +84,12 @@ class WeChatPadMessageConverter(adapter.MessageConverter): message_list = [] bot_wxid = self.config['wxid'] ats_bot = False # 是否被@ - content = message["content"]["str"] + content = message['content']['str'] content_no_preifx = content # 群消息则去掉前缀 is_group_message = self._is_group_message(message) if is_group_message: ats_bot = self._ats_bot(message, bot_account_id) + self.logger.info(f"ats_bot: {ats_bot}; bot_account_id: {bot_account_id}; bot_wxid: {bot_wxid}") if "@所有人" in content: message_list.append(platform_message.AtAll()) @@ -116,7 +104,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): content_no_preifx, _ = self._extract_content_and_sender(content) - msg_type = message["msg_type"] + msg_type = message['msg_type'] # 映射消息类型到处理器方法 handler_map = { @@ -138,11 +126,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_text( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理文本消息 (msg_type=1)""" if message and self._is_group_message(message): pattern = r'@\S{1,20}' @@ -150,16 +134,12 
@@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain([platform_message.Plain(content_no_preifx)]) - async def _handler_image( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理图像消息 (msg_type=3)""" try: image_xml = content_no_preifx if not image_xml: - return platform_message.MessageChain([platform_message.Unknown("[图片内容为空]")]) + return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')]) root = ET.fromstring(image_xml) # 提取img标签的属性 @@ -169,28 +149,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter): cdnthumburl = img_tag.get('cdnthumburl') # cdnmidimgurl = img_tag.get('cdnmidimgurl') - image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl) - if image_data["Data"]['FileData'] == '': + if image_data['Data']['FileData'] == '': image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl) - base64_str = image_data["Data"]['FileData'] + base64_str = image_data['Data']['FileData'] # self.logger.info(f"data:image/png;base64,{base64_str}") - elements = [ - platform_message.Image(base64=f"data:image/png;base64,{base64_str}"), + platform_message.Image(base64=f'data:image/png;base64,{base64_str}'), # platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发 ] return platform_message.MessageChain(elements) except Exception as e: - self.logger.error(f"处理图片失败: {str(e)}") - return platform_message.MessageChain([platform_message.Unknown("[图片处理失败]")]) + self.logger.error(f'处理图片失败: {str(e)}') + return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')]) - async def _handler_voice( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> 
platform_message.MessageChain: """处理语音消息 (msg_type=34)""" message_List = [] try: @@ -206,39 +180,33 @@ class WeChatPadMessageConverter(adapter.MessageConverter): bufid = voicemsg.get('bufid') length = voicemsg.get('voicelength') voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id)) - audio_base64 = voice_data["Data"]['Base64'] + audio_base64 = voice_data['Data']['Base64'] # 验证语音数据有效性 if not audio_base64: - message_List.append(platform_message.Unknown(text="[语音内容为空]")) + message_List.append(platform_message.Unknown(text='[语音内容为空]')) return platform_message.MessageChain(message_List) # 转换为平台支持的语音格式(如 Silk 格式) - voice_element = platform_message.Voice( - base64=f"data:audio/silk;base64,{audio_base64}" - ) + voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}') message_List.append(voice_element) except KeyError as e: - self.logger.error(f"语音数据字段缺失: {str(e)}") - message_List.append(platform_message.Unknown(text="[语音数据解析失败]")) + self.logger.error(f'语音数据字段缺失: {str(e)}') + message_List.append(platform_message.Unknown(text='[语音数据解析失败]')) except Exception as e: - self.logger.error(f"处理语音消息异常: {str(e)}") - message_List.append(platform_message.Unknown(text="[语音处理失败]")) + self.logger.error(f'处理语音消息异常: {str(e)}') + message_List.append(platform_message.Unknown(text='[语音处理失败]')) return platform_message.MessageChain(message_List) - async def _handler_compound( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理复合消息 (msg_type=49),根据子类型分派""" try: xml_data = ET.fromstring(content_no_preifx) appmsg_data = xml_data.find('.//appmsg') if appmsg_data: - data_type = appmsg_data.findtext('.//type', "") + data_type = appmsg_data.findtext('.//type', '') # 二次分派处理器 sub_handler_map = { '57': self._handler_compound_quote, @@ -247,9 +215,9 @@ class 
WeChatPadMessageConverter(adapter.MessageConverter): '74': self._handler_compound_file, '33': self._handler_compound_mini_program, '36': self._handler_compound_mini_program, - '2000': partial(self._handler_compound_unsupported, text="[转账消息]"), - '2001': partial(self._handler_compound_unsupported, text="[红包消息]"), - '51': partial(self._handler_compound_unsupported, text="[视频号消息]"), + '2000': partial(self._handler_compound_unsupported, text='[转账消息]'), + '2001': partial(self._handler_compound_unsupported, text='[红包消息]'), + '51': partial(self._handler_compound_unsupported, text='[视频号消息]'), } handler = sub_handler_map.get(data_type, self._handler_compound_unsupported) @@ -260,56 +228,54 @@ class WeChatPadMessageConverter(adapter.MessageConverter): else: return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) except Exception as e: - self.logger.error(f"解析复合消息失败: {str(e)}") + self.logger.error(f'解析复合消息失败: {str(e)}') return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) async def _handler_compound_quote( - self, - message: Optional[dict], - xml_data: ET.Element + self, message: Optional[dict], xml_data: ET.Element ) -> platform_message.MessageChain: """处理引用消息 (data_type=57)""" message_list = [] -# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) + # self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) appmsg_data = xml_data.find('.//appmsg') - quote_data = "" # 引用原文 + quote_data = '' # 引用原文 quote_id = None # 引用消息的原发送者 tousername = None # 接收方: 所属微信的wxid - user_data = "" # 用户消息 + user_data = '' # 用户消息 sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member # 引用消息转发 if appmsg_data: - user_data = appmsg_data.findtext('.//title') or "" + user_data = appmsg_data.findtext('.//title') or '' quote_data = appmsg_data.find('.//refermsg').findtext('.//content') quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') - 
message_list.append( - platform_message.WeChatAppMsg( - app_msg=ET.tostring(appmsg_data, encoding='unicode')) - ) + message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode'))) if message: - tousername = message['to_user_name']["str"] - + tousername = message['to_user_name']['str'] + + _ = quote_id + _ = tousername + if quote_data: quote_data_message_list = platform_message.MessageChain() # 文本消息 try: - if "" not in quote_data: + if '' not in quote_data: quote_data_message_list.append(platform_message.Plain(quote_data)) else: # 引用消息展开 quote_data_xml = ET.fromstring(quote_data) - if quote_data_xml.find("img"): + if quote_data_xml.find('img'): quote_data_message_list.extend(await self._handler_image(None, quote_data)) - elif quote_data_xml.find("voicemsg"): + elif quote_data_xml.find('voicemsg'): quote_data_message_list.extend(await self._handler_voice(None, quote_data)) - elif quote_data_xml.find("videomsg"): + elif quote_data_xml.find('videomsg'): quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理 else: # appmsg quote_data_message_list.extend(await self._handler_compound(None, quote_data)) except Exception as e: - self.logger.error(f"处理引用消息异常 expcetion:{e}") + self.logger.error(f'处理引用消息异常 expcetion:{e}') quote_data_message_list.append(platform_message.Plain(quote_data)) message_list.append( platform_message.Quote( @@ -324,15 +290,11 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_compound_file( - self, - message: dict, - xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理文件消息 (data_type=6)""" file_data = xml_data.find('.//appmsg') - if file_data.findtext('.//type', "") == "74": + if file_data.findtext('.//type', '') == '74': return None else: @@ -355,22 +317,21 @@ class 
WeChatPadMessageConverter(adapter.MessageConverter): file_data = self.bot.cdn_download(aeskey=aeskey, file_type=5, file_url=cdnthumburl) - file_base64 = file_data["Data"]['FileData'] + file_base64 = file_data['Data']['FileData'] # print(file_data) - file_size = file_data["Data"]['TotalSize'] + file_size = file_data['Data']['TotalSize'] # print(file_base64) - return platform_message.MessageChain([ - platform_message.WeChatFile(file_id=file_id, file_name=file_name, file_size=file_size, - file_base64=file_base64), - platform_message.WeChatForwardFile(xml_data=xml_data_str) - ]) + return platform_message.MessageChain( + [ + platform_message.WeChatFile( + file_id=file_id, file_name=file_name, file_size=file_size, file_base64=file_base64 + ), + platform_message.WeChatForwardFile(xml_data=xml_data_str), + ] + ) - async def _handler_compound_link( - self, - message: dict, - xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理链接消息(如公众号文章、外部网页)""" message_list = [] try: @@ -383,56 +344,38 @@ class WeChatPadMessageConverter(adapter.MessageConverter): link_title=appmsg.findtext('title', ''), link_desc=appmsg.findtext('des', ''), link_url=appmsg.findtext('url', ''), - link_thumb_url=appmsg.findtext("thumburl", '') # 这个字段拿不到 + link_thumb_url=appmsg.findtext('thumburl', ''), # 这个字段拿不到 ) ) # 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。 - message_list.append( - platform_message.WeChatAppMsg( - app_msg=ET.tostring(appmsg, encoding='unicode') - ) - ) + message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg, encoding='unicode'))) except Exception as e: - self.logger.error(f"解析链接消息失败: {str(e)}") + self.logger.error(f'解析链接消息失败: {str(e)}') return platform_message.MessageChain(message_list) async def _handler_compound_mini_program( - self, - message: dict, - xml_data: ET.Element + self, message: dict, xml_data: ET.Element ) -> 
platform_message.MessageChain: """处理小程序消息(如小程序卡片、服务通知)""" xml_data_str = ET.tostring(xml_data, encoding='unicode') - return platform_message.MessageChain([ - platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str) - ]) + return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)]) - async def _handler_default( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理未知消息类型""" if message: - msg_type = message["msg_type"] + msg_type = message['msg_type'] else: - msg_type = "" - return platform_message.MessageChain([ - platform_message.Unknown(text=f"[未知消息类型 msg_type:{msg_type}]") - ]) + msg_type = '' + return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')]) def _handler_compound_unsupported( - self, - message: dict, - xml_data: str, - text: Optional[str] = None + self, message: dict, xml_data: str, text: Optional[str] = None ) -> platform_message.MessageChain: """处理未支持复合消息类型(msg_type=49)子类型""" if not text: - text = f"[xml_data={xml_data}]" + text = f'[xml_data={xml_data}]' content_list = [] - content_list.append( - platform_message.Unknown(text=f"[处理未支持复合消息类型[msg_type=49]|{text}")) + content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}')) return platform_message.MessageChain(content_list) @@ -441,7 +384,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): ats_bot = False try: to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid - raw_content = message["content"]["str"] # 原始消息内容 + raw_content = message['content']['str'] # 原始消息内容 content_no_prefix, _ = self._extract_content_and_sender(raw_content) # 直接艾特机器人(这个有bug,当被引用的消息里面有@bot,会套娃 # ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix) @@ -452,7 +395,7 @@ class 
WeChatPadMessageConverter(adapter.MessageConverter): msg_source = message.get('msg_source', '') or '' if len(msg_source) > 0: msg_source_data = ET.fromstring(msg_source) - at_user_list = msg_source_data.findtext("atuserlist") or "" + at_user_list = msg_source_data.findtext('atuserlist') or '' ats_bot = ats_bot or (to_user_name in at_user_list) # 引用bot if message.get('msg_type', 0) == 49: @@ -463,7 +406,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者 ats_bot = ats_bot or (quote_id == tousername) except Exception as e: - self.logger.error(f"_ats_bot got except: {e}") + self.logger.error(f'_ats_bot got except: {e}') finally: return ats_bot @@ -489,21 +432,21 @@ class WeChatPadMessageConverter(adapter.MessageConverter): try: # 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉 # add: 有些用户的wxid不是上述格式。换成user_name: - regex = re.compile(r"^[a-zA-Z0-9_\-]{5,20}:") - line_split = raw_content.split("\n") + regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:') + line_split = raw_content.split('\n') if len(line_split) > 0 and regex.match(line_split[0]): - raw_content = "\n".join(line_split[1:]) - sender_id = line_split[0].strip(":") + raw_content = '\n'.join(line_split[1:]) + sender_id = line_split[0].strip(':') return raw_content, sender_id except Exception as e: - self.logger.error(f"_extract_content_and_sender got except: {e}") + self.logger.error(f'_extract_content_and_sender got except: {e}') finally: return raw_content, None # 是否是群消息 def _is_group_message(self, message: dict) -> bool: from_user_name = message['from_user_name']['str'] - return from_user_name.endswith("@chatroom") + return from_user_name.endswith('@chatroom') class WeChatPadEventConverter(adapter.EventConverter): @@ -514,9 +457,7 @@ class WeChatPadEventConverter(adapter.EventConverter): self.logger = logger @staticmethod - async def yiri2target( - event: platform_events.MessageEvent - ) -> dict: + async def yiri2target(event: 
platform_events.MessageEvent) -> dict: pass async def target2yiri( @@ -526,10 +467,12 @@ class WeChatPadEventConverter(adapter.EventConverter): ) -> platform_events.MessageEvent: # 排除公众号以及微信团队消息 - if event['from_user_name']['str'].startswith('gh_') \ - or event['from_user_name']['str']=='weixin'\ - or event['from_user_name']['str'] == "newsapp"\ - or event['from_user_name']['str'] == self.config["wxid"]: + if ( + event['from_user_name']['str'].startswith('gh_') + or event['from_user_name']['str'] == 'weixin' + or event['from_user_name']['str'] == 'newsapp' + or event['from_user_name']['str'] == self.config['wxid'] + ): return None message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id) @@ -538,7 +481,7 @@ class WeChatPadEventConverter(adapter.EventConverter): if '@chatroom' in event['from_user_name']['str']: # 找出开头的 wxid_ 字符串,以:结尾 - sender_wxid = event['content']['str'].split(":")[0] + sender_wxid = event['content']['str'].split(':')[0] return platform_events.GroupMessage( sender=platform_entities.GroupMember( @@ -550,13 +493,13 @@ class WeChatPadEventConverter(adapter.EventConverter): name=event['from_user_name']['str'], permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, ), message_chain=message_chain, - time=event["create_time"], + time=event['create_time'], source_platform_object=event, ) else: @@ -567,13 +510,13 @@ class WeChatPadEventConverter(adapter.EventConverter): remark='', ), message_chain=message_chain, - time=event["create_time"], + time=event['create_time'], source_platform_object=event, ) class WeChatPadAdapter(adapter.MessagePlatformAdapter): - name: str = "WeChatPad" # 定义适配器名称 + name: str = 'WeChatPad' # 定义适配器名称 bot: WeChatPadClient quart_app: quart.Quart @@ -606,27 +549,21 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # self.ap.logger.debug(f"Gewechat callback event: {data}") # print(data) - 
try: event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) - except Exception as e: - await self.logger.error(f"Error in wechatpad callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wechatpad callback: {traceback.format_exc()}') if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) return 'ok' - - async def _handle_message( - self, - message: platform_message.MessageChain, - target_id: str - ): + async def _handle_message(self, message: platform_message.MessageChain, target_id: str): """统一消息处理核心逻辑""" content_list = await self.message_converter.yiri2target(message) # print(content_list) - at_targets = [item["target"] for item in content_list if item["type"] == "at"] + at_targets = [item['target'] for item in content_list if item['type'] == 'at'] # print(at_targets) # 处理@逻辑 at_targets = at_targets or [] @@ -634,7 +571,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if at_targets: member_info = self.bot.get_chatroom_member_detail( target_id, - )["Data"]["member_data"]["chatroom_member_list"] + )['Data']['member_data']['chatroom_member_list'] # 处理消息组件 for msg in content_list: @@ -642,63 +579,51 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if msg['type'] == 'text' and at_targets: at_nick_name_list = [] for member in member_info: - if member["user_name"] in at_targets: + if member['user_name'] in at_targets: at_nick_name_list.append(f'@{member["nick_name"]}') msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}' # 统一消息派发 handler_map = { 'text': lambda msg: self.bot.send_text_message( - to_wxid=target_id, - message=msg['content'], - ats=at_targets + to_wxid=target_id, message=msg['content'], ats=at_targets ), 'image': lambda msg: self.bot.send_image_message( - to_wxid=target_id, - img_url=msg["image"], - ats = at_targets + to_wxid=target_id, img_url=msg['image'], ats=at_targets ), 'WeChatEmoji': lambda msg: 
self.bot.send_emoji_message( - to_wxid=target_id, - emoji_md5=msg['emoji_md5'], - emoji_size=msg['emoji_size'] + to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size'] ), - 'voice': lambda msg: self.bot.send_voice_message( to_wxid=target_id, voice_data=msg['data'], - voice_duration=msg["duration"], - voice_forma=msg["forma"], + voice_duration=msg['duration'], + voice_forma=msg['forma'], ), 'WeChatAppMsg': lambda msg: self.bot.send_app_message( to_wxid=target_id, app_message=msg['app_msg'], type=0, ), - 'at': lambda msg: None + 'at': lambda msg: None, } if handler := handler_map.get(msg['type']): handler(msg) # self.ap.logger.warning(f"未处理的消息类型: {ret}") else: - self.ap.logger.warning(f"未处理的消息类型: {msg['type']}") + self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') continue - async def send_message( - self, - target_type: str, - target_id: str, - message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): """主动发送消息""" return await self._handle_message(message, target_id) async def reply_message( - self, - message_source: platform_events.MessageEvent, - message: platform_message.MessageChain, - quote_origin: bool = False + self, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, + quote_origin: bool = False, ): """回复消息""" if message_source.source_platform_object: @@ -709,58 +634,49 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): pass def register_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners[event_type] = callback def unregister_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: 
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): pass async def run_async(self): - - if not self.config["admin_key"] and not self.config["token"]: - raise RuntimeError("无wechatpad管理密匙,请填入配置文件后重启") + if not self.config['admin_key'] and not self.config['token']: + raise RuntimeError('无wechatpad管理密匙,请填入配置文件后重启') else: - if self.config["token"]: - self.bot = WeChatPadClient( - self.config['wechatpad_url'], - self.config["token"] - ) + if self.config['token']: + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) data = self.bot.get_login_status() self.ap.logger.info(data) - if data["Code"] == 300 and data["Text"] == "你已退出微信": + if data['Code'] == 300 and data['Text'] == '你已退出微信': response = requests.post( - f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", - json={"Count": 1, "Days": 365} + f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', + json={'Count': 1, 'Days': 365}, ) if response.status_code != 200: - raise Exception(f"获取token失败: {response.text}") - self.config["token"] = response.json()["Data"][0] + raise Exception(f'获取token失败: {response.text}') + self.config['token'] = response.json()['Data'][0] - elif not self.config["token"]: + elif not self.config['token']: response = requests.post( - f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", - json={"Count": 1, "Days": 365} + f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', + json={'Count': 1, 'Days': 365}, ) if response.status_code != 200: - raise Exception(f"获取token失败: {response.text}") - self.config["token"] = response.json()["Data"][0] + raise Exception(f'获取token失败: {response.text}') + self.config['token'] = response.json()['Data'][0] - self.bot = WeChatPadClient( - 
self.config['wechatpad_url'], - self.config["token"], - logger=self.logger - ) - self.ap.logger.info(self.config["token"]) + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger) + self.ap.logger.info(self.config['token']) thread_1 = threading.Event() - def wechat_login_process(): # 不登录,这些先注释掉,避免登陆态尝试拉qrcode。 # login_data =self.bot.get_login_qr() @@ -768,67 +684,54 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # url = login_data['Data']["QrCodeUrl"] # self.ap.logger.info(login_data) - - profile =self.bot.get_profile() + profile = self.bot.get_profile() self.ap.logger.info(profile) - self.bot_account_id = profile["Data"]["userInfo"]["nickName"]["str"] - self.config["wxid"] = profile["Data"]["userInfo"]["userName"]["str"] + self.bot_account_id = profile['Data']['userInfo']['nickName']['str'] + self.config['wxid'] = profile['Data']['userInfo']['userName']['str'] thread_1.set() - # asyncio.create_task(wechat_login_process) threading.Thread(target=wechat_login_process).start() def connect_websocket_sync() -> None: - thread_1.wait() - uri = f"{self.config['wechatpad_ws']}/GetSyncMsg?key={self.config['token']}" - self.ap.logger.info(f"Connecting to WebSocket: {uri}") + uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}' + self.ap.logger.info(f'Connecting to WebSocket: {uri}') + def on_message(ws, message): try: data = json.loads(message) - self.ap.logger.debug(f"Received message: {data}") + self.ap.logger.debug(f'Received message: {data}') # 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法 asyncio.run(self.ws_message(data)) except json.JSONDecodeError: - self.ap.logger.error(f"Non-JSON message: {message[:100]}...") + self.ap.logger.error(f'Non-JSON message: {message[:100]}...') def on_error(ws, error): - self.ap.logger.error(f"WebSocket error: {str(error)[:200]}") + self.ap.logger.error(f'WebSocket error: {str(error)[:200]}') def on_close(ws, close_status_code, close_msg): - 
self.ap.logger.info("WebSocket closed, reconnecting...") + self.ap.logger.info('WebSocket closed, reconnecting...') time.sleep(5) connect_websocket_sync() # 自动重连 def on_open(ws): - self.ap.logger.info("WebSocket connected successfully!") + self.ap.logger.info('WebSocket connected successfully!') ws = websocket.WebSocketApp( - uri, - on_message=on_message, - on_error=on_error, - on_close=on_close, - on_open=on_open - ) - ws.run_forever( - ping_interval=60, - ping_timeout=20 + uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open ) + ws.run_forever(ping_interval=60, ping_timeout=20) # 直接调用同步版本(会阻塞) # connect_websocket_sync() # 这行代码会在WebSocket连接断开后才会执行 # self.ap.logger.info("WebSocket client thread started") - thread = threading.Thread( - target=connect_websocket_sync, - name="WebSocketClientThread", - daemon=True - ) + thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True) thread.start() - self.ap.logger.info("WebSocket client thread started") + self.ap.logger.info('WebSocket client thread started') async def kill(self) -> bool: pass diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index f1cc677e..7be05a85 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -157,7 +157,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter): token=config['token'], EncodingAESKey=config['EncodingAESKey'], contacts_secret=config['contacts_secret'], - logger=self.logger + logger=self.logger, ) async def reply_message( @@ -201,8 +201,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in wecom callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: 
self.bot.on_message('text')(on_message) diff --git a/pkg/platform/sources/wecomcs.py b/pkg/platform/sources/wecomcs.py index aab8d394..da84ac6d 100644 --- a/pkg/platform/sources/wecomcs.py +++ b/pkg/platform/sources/wecomcs.py @@ -145,7 +145,7 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): secret=config['secret'], token=config['token'], EncodingAESKey=config['EncodingAESKey'], - logger=self.logger + logger=self.logger, ) async def reply_message( @@ -178,8 +178,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in wecomcs callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py index cf856894..7bc02a32 100644 --- a/pkg/provider/modelmgr/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel): token_mgr: token.TokenManager - requester: requester.LLMAPIRequester + requester: requester.ProviderAPIRequester tool_call_supported: typing.Optional[bool] = False diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index b15e53a9..2c92eacc 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -18,7 +18,7 @@ class ModelManager: model_list: list[entities.LLMModelInfo] # deprecated - requesters: dict[str, requester.LLMAPIRequester] # deprecated + requesters: dict[str, requester.ProviderAPIRequester] # deprecated token_mgrs: dict[str, token.TokenManager] # deprecated @@ -28,9 +28,11 @@ class ModelManager: llm_models: list[requester.RuntimeLLMModel] + embedding_models: list[requester.RuntimeEmbeddingModel] + 
requester_components: list[engine.Component] - requester_dict: dict[str, type[requester.LLMAPIRequester]] # cache + requester_dict: dict[str, type[requester.ProviderAPIRequester]] # cache def __init__(self, ap: app.Application): self.ap = ap @@ -38,6 +40,7 @@ class ModelManager: self.requesters = {} self.token_mgrs = {} self.llm_models = [] + self.embedding_models = [] self.requester_components = [] self.requester_dict = {} @@ -45,7 +48,7 @@ class ModelManager: self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') # forge requester class dict - requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} + requester_dict: dict[str, type[requester.ProviderAPIRequester]] = {} for component in self.requester_components: requester_dict[component.metadata.name] = component.get_python_component_class() @@ -58,13 +61,11 @@ class ModelManager: self.ap.logger.info('Loading models from db...') self.llm_models = [] + self.embedding_models = [] # llm models result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) - llm_models = result.all() - - # load models for llm_model in llm_models: try: await self.load_llm_model(llm_model) @@ -73,11 +74,17 @@ class ModelManager: except Exception as e: self.ap.logger.error(f'Failed to load model {llm_model.uuid}: {e}\n{traceback.format_exc()}') + # embedding models + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.EmbeddingModel)) + embedding_models = result.all() + for embedding_model in embedding_models: + await self.load_embedding_model(embedding_model) + async def init_runtime_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """初始化运行时模型""" + """初始化运行时 LLM 模型""" if isinstance(model_info, sqlalchemy.Row): model_info = persistence_model.LLMModel(**model_info._mapping) elif isinstance(model_info, dict): @@ -101,14 +108,47 @@ class ModelManager: return 
runtime_llm_model + async def init_runtime_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """初始化运行时 Embedding 模型""" + if isinstance(model_info, sqlalchemy.Row): + model_info = persistence_model.EmbeddingModel(**model_info._mapping) + elif isinstance(model_info, dict): + model_info = persistence_model.EmbeddingModel(**model_info) + + requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) + + await requester_inst.initialize() + + runtime_embedding_model = requester.RuntimeEmbeddingModel( + model_entity=model_info, + token_mgr=token.TokenManager( + name=model_info.uuid, + tokens=model_info.api_keys, + ), + requester=requester_inst, + ) + + return runtime_embedding_model + async def load_llm_model( self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): - """加载模型""" + """加载 LLM 模型""" runtime_llm_model = await self.init_runtime_llm_model(model_info) self.llm_models.append(runtime_llm_model) + async def load_embedding_model( + self, + model_info: persistence_model.EmbeddingModel | sqlalchemy.Row[persistence_model.EmbeddingModel] | dict, + ): + """加载 Embedding 模型""" + runtime_embedding_model = await self.init_runtime_embedding_model(model_info) + self.embedding_models.append(runtime_embedding_model) + async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated """通过名称获取模型""" for model in self.model_list: @@ -116,23 +156,44 @@ class ModelManager: return model raise ValueError(f'无法确定模型 {name} 的信息') - async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo: - """通过uuid获取模型""" + async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel: + """通过uuid获取 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == uuid: return model - raise ValueError(f'model {uuid} not found') + raise ValueError(f'LLM model {uuid} not found') + + 
async def get_embedding_model_by_uuid(self, uuid: str) -> requester.RuntimeEmbeddingModel: + """通过uuid获取 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == uuid: + return model + raise ValueError(f'Embedding model {uuid} not found') async def remove_llm_model(self, model_uuid: str): - """移除模型""" + """移除 LLM 模型""" for model in self.llm_models: if model.model_entity.uuid == model_uuid: self.llm_models.remove(model) return - def get_available_requesters_info(self) -> list[dict]: + async def remove_embedding_model(self, model_uuid: str): + """移除 Embedding 模型""" + for model in self.embedding_models: + if model.model_entity.uuid == model_uuid: + self.embedding_models.remove(model) + return + + def get_available_requesters_info(self, model_type: str) -> list[dict]: """获取所有可用的请求器""" - return [component.to_plain_dict() for component in self.requester_components] + if model_type != '': + return [ + component.to_plain_dict() + for component in self.requester_components + if model_type in component.spec['support_type'] + ] + else: + return [component.to_plain_dict() for component in self.requester_components] def get_available_requester_info_by_name(self, name: str) -> dict | None: """通过名称获取请求器信息""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 244f4c82..17697cdb 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -20,22 +20,45 @@ class RuntimeLLMModel: token_mgr: token.TokenManager """api key管理器""" - requester: LLMAPIRequester + requester: ProviderAPIRequester """请求器实例""" def __init__( self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, - requester: LLMAPIRequester, + requester: ProviderAPIRequester, ): self.model_entity = model_entity self.token_mgr = token_mgr self.requester = requester -class LLMAPIRequester(metaclass=abc.ABCMeta): - """LLM API请求器""" +class RuntimeEmbeddingModel: + """运行时 Embedding 模型""" + + model_entity: 
persistence_model.EmbeddingModel + """模型数据""" + + token_mgr: token.TokenManager + """api key管理器""" + + requester: ProviderAPIRequester + """请求器实例""" + + def __init__( + self, + model_entity: persistence_model.EmbeddingModel, + token_mgr: token.TokenManager, + requester: ProviderAPIRequester, + ): + self.model_entity = model_entity + self.token_mgr = token_mgr + self.requester = requester + + +class ProviderAPIRequester(metaclass=abc.ABCMeta): + """Provider API请求器""" name: str = None @@ -74,3 +97,22 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): llm_entities.Message: 返回消息对象 """ pass + + async def invoke_embedding( + self, + model: RuntimeEmbeddingModel, + input_text: list[str], + extra_args: dict[str, typing.Any] = {}, + ) -> list[list[float]]: + """调用 Embedding API + + Args: + query (core_entities.Query): 请求上下文 + model (RuntimeEmbeddingModel): 使用的模型信息 + input_text (list[str]): 输入文本 + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + + Returns: + list[list[float]]: 返回的 embedding 向量 + """ + pass diff --git a/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml b/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml index 2d9df778..754a9078 100644 --- a/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/302aichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./302aichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 38573854..b195ae51 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -15,7 +15,7 @@ from ...tools import entities as tools_entities from ....utils import image -class AnthropicMessages(requester.LLMAPIRequester): +class AnthropicMessages(requester.ProviderAPIRequester): """Anthropic Messages API 请求器""" client: anthropic.AsyncAnthropic diff --git 
a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index c124fed9..7dbcf3ed 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./anthropicmsgs.py diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index 24beb915..10aae30f 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./bailianchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 513086e5..aaaf3751 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -13,7 +13,7 @@ from ... 
import entities as llm_entities from ...tools import entities as tools_entities -class OpenAIChatCompletions(requester.LLMAPIRequester): +class OpenAIChatCompletions(requester.ProviderAPIRequester): """OpenAI ChatCompletion API 请求器""" client: openai.AsyncClient @@ -141,3 +141,39 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def invoke_embedding( + self, + model: requester.RuntimeEmbeddingModel, + input_text: list[str], + extra_args: dict[str, typing.Any] = {}, + ) -> list[list[float]]: + """调用 Embedding API""" + self.client.api_key = model.token_mgr.get_token() + + args = { + 'model': model.model_entity.name, + 'input': input_text, + } + + if model.model_entity.extra_args: + args.update(model.model_entity.extra_args) + + args.update(extra_args) + + try: + resp = await self.client.embeddings.create(**args) + + return [d.embedding for d in resp.data] + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + raise errors.RequesterError(f'请求参数错误: {e.message}') + except openai.AuthenticationError as e: + raise errors.RequesterError(f'无效的 api-key: {e.message}') + except openai.NotFoundError as e: + raise errors.RequesterError(f'请求路径错误: {e.message}') + except openai.RateLimitError as e: + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') + except openai.APIError as e: + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/pkg/provider/modelmgr/requesters/chatcmpl.yaml index 908b30ac..ff0de6f9 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -22,6 +22,9 @@ spec: type: integer required: true default: 120 + support_type: + - llm + - text-embedding execution: python: path: ./chatcmpl.py diff --git 
a/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml b/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml index ca57c31c..2b7f9a70 100644 --- a/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/compsharechatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./compsharechatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml index ea2c7eea..6f320e66 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./deepseekchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml index 6bfc085e..73fca19c 100644 --- a/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/geminichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./geminichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index a18675a1..3a79bb49 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./giteeaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml index 893235b2..fbe57dad 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + 
support_type: + - llm execution: python: path: ./lmstudiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index b8868f4d..4708f671 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -14,7 +14,7 @@ from ... import entities as llm_entities from ...tools import entities as tools_entities -class ModelScopeChatCompletions(requester.LLMAPIRequester): +class ModelScopeChatCompletions(requester.ProviderAPIRequester): """ModelScope ChatCompletion API 请求器""" client: openai.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml index a641a672..a926d889 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./modelscopechatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml index f3ae73c8..52f7bcda 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./moonshotchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 2ea4bb7d..1456515f 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -17,7 +17,7 @@ from ....core import entities as core_entities REQUESTER_NAME: str = 'ollama-chat' -class OllamaChatCompletions(requester.LLMAPIRequester): +class OllamaChatCompletions(requester.ProviderAPIRequester): """Ollama平台 ChatCompletion API请求器""" client: 
ollama.AsyncClient diff --git a/pkg/provider/modelmgr/requesters/ollamachat.yaml b/pkg/provider/modelmgr/requesters/ollamachat.yaml index 01435775..f4c4bf5a 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.yaml +++ b/pkg/provider/modelmgr/requesters/ollamachat.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./ollamachat.py diff --git a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml index 2ecee6cc..ea35bce6 100644 --- a/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/openrouterchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./openrouterchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml index 9f201aa9..a5a3421c 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.yaml @@ -29,6 +29,8 @@ spec: type: int required: true default: 120 + support_type: + - llm execution: python: path: ./ppiochatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 19b3dcc3..3872cb6f 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./siliconflowchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml index 402f04e7..c711ef2d 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + 
support_type: + - llm execution: python: path: ./volcarkchatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml index 29db4eb3..2769a402 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./xaichatcmpl.py diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml index a05184ef..34539d95 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml @@ -22,6 +22,8 @@ spec: type: integer required: true default: 120 + support_type: + - llm execution: python: path: ./zhipuaichatcmpl.py diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 7d5e04c5..1d3e88ac 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -1,13 +1,28 @@ from __future__ import annotations import json +import copy import typing - from .. import runner from ...core import entities as core_entities from .. import entities as llm_entities +rag_combined_prompt_template = """ +The following are relevant context entries retrieved from the knowledge base. +Please use them to answer the user's message. +Respond in the same language as the user's input. 
+ + +{rag_context} + + + +{user_message} + +""" + + @runner.runner_class('local-agent') class LocalAgentRunner(runner.RequestRunner): """本地Agent请求运行器""" @@ -16,7 +31,54 @@ class LocalAgentRunner(runner.RequestRunner): """运行请求""" pending_tool_calls = [] - req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] + kb_uuid = query.pipeline_config['ai']['local-agent']['knowledge-base'] + + if kb_uuid == '__none__': + kb_uuid = None + + user_message = copy.deepcopy(query.user_message) + + user_message_text = '' + + if isinstance(user_message.content, str): + user_message_text = user_message.content + elif isinstance(user_message.content, list): + for ce in user_message.content: + if ce.type == 'text': + user_message_text += ce.text + break + + if kb_uuid and user_message_text: + # only support text for now + kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + + if not kb: + self.ap.logger.warning(f'Knowledge base {kb_uuid} not found') + raise ValueError(f'Knowledge base {kb_uuid} not found') + + result = await kb.retrieve(user_message_text) + + final_user_message_text = '' + + if result: + rag_context = '\n\n'.join( + f'[{i + 1}] {entry.metadata.get("text", "")}' for i, entry in enumerate(result) + ) + final_user_message_text = rag_combined_prompt_template.format( + rag_context=rag_context, user_message=user_message_text + ) + + else: + final_user_message_text = user_message_text + + self.ap.logger.debug(f'Final user message text: {final_user_message_text}') + + for ce in user_message.content: + if ce.type == 'text': + ce.text = final_user_message_text + break + + req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message] # 首次请求 msg = await query.use_llm_model.requester.invoke_llm( diff --git a/pkg/rag/knowledge/kbmgr.py b/pkg/rag/knowledge/kbmgr.py new file mode 100644 index 00000000..a9e7e57a --- /dev/null +++ b/pkg/rag/knowledge/kbmgr.py @@ -0,0 +1,212 @@ +from __future__ import annotations 
+import traceback +import uuid +from .services import parser, chunker +from pkg.core import app +from pkg.rag.knowledge.services.embedder import Embedder +from pkg.rag.knowledge.services.retriever import Retriever +import sqlalchemy +from ...entity.persistence import rag as persistence_rag +from pkg.core import taskmgr +from ...entity.rag import retriever as retriever_entities + + +class RuntimeKnowledgeBase: + ap: app.Application + + knowledge_base_entity: persistence_rag.KnowledgeBase + + parser: parser.FileParser + + chunker: chunker.Chunker + + embedder: Embedder + + retriever: Retriever + + def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase): + self.ap = ap + self.knowledge_base_entity = knowledge_base_entity + self.parser = parser.FileParser(ap=self.ap) + self.chunker = chunker.Chunker(ap=self.ap) + self.embedder = Embedder(ap=self.ap) + self.retriever = Retriever(ap=self.ap) + # 传递kb_id给retriever + self.retriever.kb_id = knowledge_base_entity.uuid + + async def initialize(self): + pass + + async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext): + try: + # set file status to processing + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='processing') + ) + + task_context.set_current_action('Parsing file') + # parse file + text = await self.parser.parse(file.file_name, file.extension) + if not text: + raise Exception(f'No text extracted from file {file.file_name}') + + task_context.set_current_action('Chunking file') + # chunk file + chunks_texts = await self.chunker.chunk(text) + if not chunks_texts: + raise Exception(f'No chunks extracted from file {file.file_name}') + + task_context.set_current_action('Embedding chunks') + + embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid( + self.knowledge_base_entity.embedding_model_uuid + ) + # embed chunks + await 
self.embedder.embed_and_store( + kb_id=self.knowledge_base_entity.uuid, + file_id=file.uuid, + chunks=chunks_texts, + embedding_model=embedding_model, + ) + + # set file status to completed + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='completed') + ) + + except Exception as e: + self.ap.logger.error(f'Error storing file {file.uuid}: {e}') + traceback.print_exc() + # set file status to failed + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(persistence_rag.File) + .where(persistence_rag.File.uuid == file.uuid) + .values(status='failed') + ) + + raise + + async def store_file(self, file_id: str) -> str: + # pre checking + if not await self.ap.storage_mgr.storage_provider.exists(file_id): + raise Exception(f'File {file_id} not found') + + file_uuid = str(uuid.uuid4()) + kb_id = self.knowledge_base_entity.uuid + file_name = file_id + extension = file_name.split('.')[-1] + + file_obj_data = { + 'uuid': file_uuid, + 'kb_id': kb_id, + 'file_name': file_name, + 'extension': extension, + 'status': 'pending', + } + + file_obj = persistence_rag.File(**file_obj_data) + + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(file_obj_data)) + + # run background task asynchronously + ctx = taskmgr.TaskContext.new() + wrapper = self.ap.task_mgr.create_user_task( + self._store_file_task(file_obj, task_context=ctx), + kind='knowledge-operation', + name=f'knowledge-store-file-{file_id}', + label=f'Store file {file_id}', + context=ctx, + ) + return wrapper.id + + async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]: + embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid( + self.knowledge_base_entity.embedding_model_uuid + ) + return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model) + + async def delete_file(self, file_id: str): + # 
delete vector + await self.ap.vector_db_mgr.vector_db.delete_by_file_id(self.knowledge_base_entity.uuid, file_id) + + # delete chunk + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.Chunk).where(persistence_rag.Chunk.file_id == file_id) + ) + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id) + ) + + async def dispose(self): + await self.ap.vector_db_mgr.vector_db.delete_collection(self.knowledge_base_entity.uuid) + + +class RAGManager: + ap: app.Application + + knowledge_bases: list[RuntimeKnowledgeBase] + + def __init__(self, ap: app.Application): + self.ap = ap + self.knowledge_bases = [] + + async def initialize(self): + await self.load_knowledge_bases_from_db() + + async def load_knowledge_bases_from_db(self): + self.ap.logger.info('Loading knowledge bases from db...') + + self.knowledge_bases = [] + + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase)) + + knowledge_bases = result.all() + + for knowledge_base in knowledge_bases: + try: + await self.load_knowledge_base(knowledge_base) + except Exception as e: + self.ap.logger.error( + f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}' + ) + + async def load_knowledge_base( + self, + knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict, + ) -> RuntimeKnowledgeBase: + if isinstance(knowledge_base_entity, sqlalchemy.Row): + knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping) + elif isinstance(knowledge_base_entity, dict): + knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity) + + runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity) + + await runtime_knowledge_base.initialize() + + self.knowledge_bases.append(runtime_knowledge_base) + + return runtime_knowledge_base + + async def 
get_knowledge_base_by_uuid(self, kb_uuid: str) -> RuntimeKnowledgeBase | None: + for kb in self.knowledge_bases: + if kb.knowledge_base_entity.uuid == kb_uuid: + return kb + return None + + async def remove_knowledge_base_from_runtime(self, kb_uuid: str): + for kb in self.knowledge_bases: + if kb.knowledge_base_entity.uuid == kb_uuid: + self.knowledge_bases.remove(kb) + return + + async def delete_knowledge_base(self, kb_uuid: str): + for kb in self.knowledge_bases: + if kb.knowledge_base_entity.uuid == kb_uuid: + await kb.dispose() + self.knowledge_bases.remove(kb) + return diff --git a/pkg/rag/knowledge/services/__init__.py b/pkg/rag/knowledge/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py new file mode 100644 index 00000000..0f71a508 --- /dev/null +++ b/pkg/rag/knowledge/services/base_service.py @@ -0,0 +1,15 @@ +# 封装异步操作 +import asyncio + + +class BaseService: + def __init__(self): + pass + + async def _run_sync(self, func, *args, **kwargs): + """ + 在单独的线程中运行同步函数。 + 如果第一个参数是 session,则在 to_thread 中获取新的 session。 + """ + + return await asyncio.to_thread(func, *args, **kwargs) diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py new file mode 100644 index 00000000..f169d5f1 --- /dev/null +++ b/pkg/rag/knowledge/services/chunker.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import json +from typing import List +from pkg.rag.knowledge.services import base_service +from pkg.core import app + + +class Chunker(base_service.BaseService): + """ + A class for splitting long texts into smaller, overlapping chunks. 
+ """ + + def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50): + self.ap = ap + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + if self.chunk_overlap >= self.chunk_size: + self.ap.logger.warning( + 'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.' + ) + + def _split_text_sync(self, text: str) -> List[str]: + """ + Synchronously splits a long text into chunks with specified overlap. + This is a CPU-bound operation, intended to be run in a separate thread. + """ + if not text: + return [] + # words = text.split() + # chunks = [] + # current_chunk = [] + + # for word in words: + # current_chunk.append(word) + # if len(current_chunk) > self.chunk_size: + # chunks.append(" ".join(current_chunk[:self.chunk_size])) + # current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:] + + # if current_chunk: + # chunks.append(" ".join(current_chunk)) + + # A more robust chunking strategy (e.g., using recursive character text splitter) + from langchain.text_splitter import RecursiveCharacterTextSplitter + + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.chunk_overlap, + length_function=len, + is_separator_regex=False, + ) + return text_splitter.split_text(text) + + async def chunk(self, text: str) -> List[str]: + """ + Asynchronously chunks a given text into smaller pieces. 
+ """ + self.ap.logger.info(f'Chunking text (length: {len(text)})...') + # Run the synchronous splitting logic in a separate thread + chunks = await self._run_sync(self._split_text_sync, text) + self.ap.logger.info(f'Text chunked into {len(chunks)} pieces.') + self.ap.logger.debug(f'Chunks: {json.dumps(chunks, indent=4, ensure_ascii=False)}') + return chunks diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py new file mode 100644 index 00000000..a0ae3d49 --- /dev/null +++ b/pkg/rag/knowledge/services/embedder.py @@ -0,0 +1,47 @@ +from __future__ import annotations +import uuid +from typing import List +from pkg.rag.knowledge.services.base_service import BaseService +from ....entity.persistence import rag as persistence_rag +from ....core import app +from ....provider.modelmgr.requester import RuntimeEmbeddingModel +import sqlalchemy + + +class Embedder(BaseService): + def __init__(self, ap: app.Application) -> None: + super().__init__() + self.ap = ap + + async def embed_and_store( + self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel + ) -> list[persistence_rag.Chunk]: + # save chunk to db + chunk_entities: list[persistence_rag.Chunk] = [] + chunk_ids: list[str] = [] + + for chunk_text in chunks: + chunk_uuid = str(uuid.uuid4()) + chunk_ids.append(chunk_uuid) + chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text) + chunk_entities.append(chunk_entity) + + chunk_dicts = [ + self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities + ] + + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts)) + + # get embeddings + embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding( + model=embedding_model, + input_text=chunks, + extra_args={}, # TODO: add extra args + ) + + # save embeddings to vdb + await 
self.ap.vector_db_mgr.vector_db.add_embeddings(kb_id, chunk_ids, embeddings_list, chunk_dicts) + + self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.') + + return chunk_entities diff --git a/pkg/rag/knowledge/services/parser.py b/pkg/rag/knowledge/services/parser.py new file mode 100644 index 00000000..004dbdaa --- /dev/null +++ b/pkg/rag/knowledge/services/parser.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import PyPDF2 +import io +from docx import Document +import chardet +from typing import Union, Callable, Any +import markdown +from bs4 import BeautifulSoup +import re +import asyncio # Import asyncio for async operations +from pkg.core import app + + +class FileParser: + """ + A robust file parser class to extract text content from various document formats. + It supports TXT, PDF, DOCX, XLSX, CSV, Markdown, HTML, and EPUB files. + All core file reading operations are designed to be run synchronously in a thread pool + to avoid blocking the asyncio event loop. + """ + + def __init__(self, ap: app.Application): + self.ap = ap + + async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any: + """ + Runs a synchronous function in a separate thread to prevent blocking the event loop. + This is a general utility method for wrapping blocking I/O operations. + """ + try: + return await asyncio.to_thread(sync_func, *args, **kwargs) + except Exception as e: + self.ap.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}') + raise + + async def parse(self, file_name: str, extension: str) -> Union[str, None]: + """ + Parses the file based on its extension and returns the extracted text content. + This is the main asynchronous entry point for parsing. + + Args: + file_name (str): The name of the file to be parsed, get from ap.storage_mgr + + Returns: + Union[str, None]: The extracted text content as a single string, or None if parsing fails. 
+ """ + + file_extension = extension.lower() + parser_method = getattr(self, f'_parse_{file_extension}', None) + + if parser_method is None: + self.ap.logger.error(f'Unsupported file format: {file_extension} for file {file_name}') + return None + + try: + # Pass file_path to the specific parser methods + return await parser_method(file_name) + except Exception as e: + self.ap.logger.error(f'Failed to parse {file_extension} file {file_name}: {e}') + return None + + # --- Helper for reading files with encoding detection --- + async def _read_file_content(self, file_name: str) -> Union[str, bytes]: + """ + Reads a file with automatic encoding detection, ensuring the synchronous + file read operation runs in a separate thread. + """ + + # def _read_sync(): + # with open(file_path, 'rb') as file: + # raw_data = file.read() + # detected = chardet.detect(raw_data) + # encoding = detected['encoding'] or 'utf-8' + + # if mode == 'r': + # return raw_data.decode(encoding, errors='ignore') + # return raw_data # For binary mode + + # return await self._run_sync(_read_sync) + file_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + + detected = chardet.detect(file_bytes) + encoding = detected['encoding'] or 'utf-8' + + return file_bytes.decode(encoding, errors='ignore') + + # --- Specific Parser Methods --- + + async def _parse_txt(self, file_name: str) -> str: + """Parses a TXT file and returns its content.""" + self.ap.logger.info(f'Parsing TXT file: {file_name}') + return await self._read_file_content(file_name) + + async def _parse_pdf(self, file_name: str) -> str: + """Parses a PDF file and returns its text content.""" + self.ap.logger.info(f'Parsing PDF file: {file_name}') + + # def _parse_pdf_sync(): + # text_content = [] + # with open(file_name, 'rb') as file: + # pdf_reader = PyPDF2.PdfReader(file) + # for page in pdf_reader.pages: + # text = page.extract_text() + # if text: + # text_content.append(text) + # return '\n'.join(text_content) + + # return 
await self._run_sync(_parse_pdf_sync) + + pdf_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + + def _parse_pdf_sync(): + pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes)) + text_content = [] + for page in pdf_reader.pages: + text = page.extract_text() + if text: + text_content.append(text) + return '\n'.join(text_content) + + return await self._run_sync(_parse_pdf_sync) + + async def _parse_docx(self, file_name: str) -> str: + """Parses a DOCX file and returns its text content.""" + self.ap.logger.info(f'Parsing DOCX file: {file_name}') + + docx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + + def _parse_docx_sync(): + doc = Document(io.BytesIO(docx_bytes)) + text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()] + return '\n'.join(text_content) + + return await self._run_sync(_parse_docx_sync) + + async def _parse_doc(self, file_name: str) -> str: + """Handles .doc files, explicitly stating lack of direct support.""" + self.ap.logger.warning(f'Direct .doc parsing is not supported for {file_name}. Please convert to .docx first.') + raise NotImplementedError('Direct .doc parsing not supported. 
Please convert to .docx first.') + + # async def _parse_xlsx(self, file_name: str) -> str: + # """Parses an XLSX file, returning text from all sheets.""" + # self.ap.logger.info(f'Parsing XLSX file: {file_name}') + + # xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + + # def _parse_xlsx_sync(): + # excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes)) + # all_sheet_content = [] + # for sheet_name in excel_file.sheet_names: + # df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name) + # sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n' + # all_sheet_content.append(sheet_text) + # return '\n'.join(all_sheet_content) + + # return await self._run_sync(_parse_xlsx_sync) + + # async def _parse_csv(self, file_name: str) -> str: + # """Parses a CSV file and returns its content as a string.""" + # self.ap.logger.info(f'Parsing CSV file: {file_name}') + + # csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + + # def _parse_csv_sync(): + # # pd.read_csv can often detect encoding, but explicit detection is safer + # # raw_data = self._read_file_content( + # # file_name, mode='rb' + # # ) # Note: this will need to be await outside this sync function + # # _ = raw_data + # # For simplicity, we'll let pandas handle encoding internally after a raw read. + # # A more robust solution might pass encoding directly to pd.read_csv after detection. 
+ # detected = chardet.detect(io.BytesIO(csv_bytes)) + # encoding = detected['encoding'] or 'utf-8' + # df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding) + # return df.to_string(index=False) + + # return await self._run_sync(_parse_csv_sync) + + async def _parse_md(self, file_name: str) -> str: + """Parses a Markdown file, converting it to structured plain text.""" + self.ap.logger.info(f'Parsing Markdown file: {file_name}') + + md_bytes = await self.ap.storage_mgr.storage_provider.load(file_name) + + def _parse_markdown_sync(): + md_content = io.BytesIO(md_bytes).read().decode('utf-8', errors='ignore') + html_content = markdown.markdown( + md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code'] + ) + soup = BeautifulSoup(html_content, 'html.parser') + text_parts = [] + for element in soup.children: + if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']: + level = int(element.name[1]) + text_parts.append('#' * level + ' ' + element.get_text().strip()) + elif element.name == 'p': + text = element.get_text().strip() + if text: + text_parts.append(text) + elif element.name in ['ul', 'ol']: + for li in element.find_all('li'): + text_parts.append(f'* {li.get_text().strip()}') + elif element.name == 'pre': + code_block = element.get_text().strip() + if code_block: + text_parts.append(f'```\n{code_block}\n```') + elif element.name == 'table': + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + if table_str: + text_parts.append(table_str) + elif element.name: + text = element.get_text(separator=' ', strip=True) + if text: + text_parts.append(text) + cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) + return cleaned_text.strip() + + return await self._run_sync(_parse_markdown_sync) + + async def _parse_html(self, file_name: str) -> str: + """Parses an HTML file, extracting structured plain text.""" + self.ap.logger.info(f'Parsing HTML file: {file_name}') + + html_bytes = await 
self.ap.storage_mgr.storage_provider.load(file_name) + + def _parse_html_sync(): + html_content = io.BytesIO(html_bytes).read().decode('utf-8', errors='ignore') + soup = BeautifulSoup(html_content, 'html.parser') + for script_or_style in soup(['script', 'style']): + script_or_style.decompose() + text_parts = [] + for element in soup.body.children if soup.body else soup.children: + if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']: + level = int(element.name[1]) + text_parts.append('#' * level + ' ' + element.get_text().strip()) + elif element.name == 'p': + text = element.get_text().strip() + if text: + text_parts.append(text) + elif element.name in ['ul', 'ol']: + for li in element.find_all('li'): + text = li.get_text().strip() + if text: + text_parts.append(f'* {text}') + elif element.name == 'table': + table_str = self._extract_table_to_markdown_sync(element) # Call sync helper + if table_str: + text_parts.append(table_str) + elif element.name: + text = element.get_text(separator=' ', strip=True) + if text: + text_parts.append(text) + cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts)) + return cleaned_text.strip() + + return await self._run_sync(_parse_html_sync) + + def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int): + """Recursively adds TOC items to text_content (synchronous helper).""" + indent = ' ' * level + for item in toc_list: + if isinstance(item, tuple): + chapter, subchapters = item + text_content.append(f'{indent}- {chapter.title}') + self._add_toc_items_sync(subchapters, text_content, level + 1) + else: + text_content.append(f'{indent}- {item.title}') + + def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str: + """Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous).""" + headers = [th.get_text().strip() for th in table_element.find_all('th')] + rows = [] + for tr in table_element.find_all('tr'): + cells = [td.get_text().strip() for td in 
tr.find_all('td')] + if cells: + rows.append(cells) + + if not headers and not rows: + return '' + + table_lines = [] + if headers: + table_lines.append(' | '.join(headers)) + table_lines.append(' | '.join(['---'] * len(headers))) + + for row_cells in rows: + padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells + table_lines.append(' | '.join(padded_cells)) + + return '\n'.join(table_lines) diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py new file mode 100644 index 00000000..73c7edaa --- /dev/null +++ b/pkg/rag/knowledge/services/retriever.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from . import base_service +from ....core import app +from ....provider.modelmgr.requester import RuntimeEmbeddingModel +from ....entity.rag import retriever as retriever_entities + + +class Retriever(base_service.BaseService): + def __init__(self, ap: app.Application): + super().__init__() + self.ap = ap + + async def retrieve( + self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5 + ) -> list[retriever_entities.RetrieveResultEntry]: + self.ap.logger.info( + f"Retrieving for query: '{query[:10]}' with k={k} using {embedding_model.model_entity.uuid}" + ) + + query_embedding: list[float] = await embedding_model.requester.invoke_embedding( + model=embedding_model, + input_text=[query], + extra_args={}, # TODO: add extra args + ) + + chroma_results = await self.ap.vector_db_mgr.vector_db.search(kb_id, query_embedding[0], k) + + # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' + matched_chroma_ids = chroma_results.get('ids', [[]])[0] + distances = chroma_results.get('distances', [[]])[0] + chroma_metadatas = chroma_results.get('metadatas', [[]])[0] + + if not matched_chroma_ids: + self.ap.logger.info('No relevant chunks found in Chroma.') + return [] + + result: list[retriever_entities.RetrieveResultEntry] = [] + + for i, id in 
enumerate(matched_chroma_ids): + entry = retriever_entities.RetrieveResultEntry( + id=id, + metadata=chroma_metadatas[i], + distance=distances[i], + ) + result.append(entry) + + return result diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index fc1d1f49..4886d186 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,7 +1,7 @@ semantic_version = 'v4.0.9' -required_database_version = 3 -"""标记本版本所需要的数据库结构版本,用于判断数据库迁移""" +required_database_version = 4 +"""Tag the version of the database schema, used to check if the database needs to be migrated""" debug_mode = False diff --git a/pkg/vector/__init__.py b/pkg/vector/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/vector/mgr.py b/pkg/vector/mgr.py new file mode 100644 index 00000000..ea198ac2 --- /dev/null +++ b/pkg/vector/mgr.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ..core import app +from .vdb import VectorDatabase +from .vdbs.chroma import ChromaVectorDatabase + + +class VectorDBManager: + ap: app.Application + vector_db: VectorDatabase = None + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + # 初始化 Chroma 向量数据库(可扩展为多种实现) + if self.vector_db is None: + self.vector_db = ChromaVectorDatabase(self.ap) diff --git a/pkg/vector/vdb.py b/pkg/vector/vdb.py new file mode 100644 index 00000000..73a3cc0e --- /dev/null +++ b/pkg/vector/vdb.py @@ -0,0 +1,37 @@ +from __future__ import annotations +import abc +from typing import Any, Dict +import numpy as np + + +class VectorDatabase(abc.ABC): + @abc.abstractmethod + async def add_embeddings( + self, + collection: str, + ids: list[str], + embeddings_list: list[list[float]], + metadatas: list[dict[str, Any]], + documents: list[str], + ) -> None: + """向指定 collection 添加向量数据。""" + pass + + @abc.abstractmethod + async def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]: + """在指定 collection 中检索最相似的向量。""" + pass + + 
@abc.abstractmethod + async def delete_by_file_id(self, collection: str, file_id: str) -> None: + """根据 file_id 删除指定 collection 中的向量。""" + pass + + @abc.abstractmethod + async def get_or_create_collection(self, collection: str): + """获取或创建 collection。""" + pass + + @abc.abstractmethod + async def delete_collection(self, collection: str): + pass diff --git a/pkg/vector/vdbs/__init__.py b/pkg/vector/vdbs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/vector/vdbs/chroma.py b/pkg/vector/vdbs/chroma.py new file mode 100644 index 00000000..41ab7d36 --- /dev/null +++ b/pkg/vector/vdbs/chroma.py @@ -0,0 +1,61 @@ +from __future__ import annotations +import asyncio +from typing import Any +from chromadb import PersistentClient +from pkg.vector.vdb import VectorDatabase +from pkg.core import app +import chromadb +import chromadb.errors + + +class ChromaVectorDatabase(VectorDatabase): + def __init__(self, ap: app.Application, base_path: str = './data/chroma'): + self.ap = ap + self.client = PersistentClient(path=base_path) + self._collections = {} + + async def get_or_create_collection(self, collection: str) -> chromadb.Collection: + if collection not in self._collections: + self._collections[collection] = await asyncio.to_thread( + self.client.get_or_create_collection, name=collection + ) + self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.") + return self._collections[collection] + + async def add_embeddings( + self, + collection: str, + ids: list[str], + embeddings_list: list[list[float]], + metadatas: list[dict[str, Any]], + ) -> None: + col = await self.get_or_create_collection(collection) + await asyncio.to_thread(col.add, embeddings=embeddings_list, ids=ids, metadatas=metadatas) + self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.") + + async def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]: + col = await 
self.get_or_create_collection(collection) + results = await asyncio.to_thread( + col.query, + query_embeddings=query_embedding, + n_results=k, + include=['metadatas', 'distances', 'documents'], + ) + self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.") + return results + + async def delete_by_file_id(self, collection: str, file_id: str) -> None: + col = await self.get_or_create_collection(collection) + await asyncio.to_thread(col.delete, where={'file_id': file_id}) + self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with file_id: {file_id}") + + async def delete_collection(self, collection: str): + if collection in self._collections: + del self._collections[collection] + + try: + await asyncio.to_thread(self.client.delete_collection, name=collection) + except chromadb.errors.NotFoundError: + self.ap.logger.warning(f"Chroma collection '{collection}' not found.") + return + self.ap.logger.info(f"Chroma collection '{collection}' deleted.") diff --git a/pyproject.toml b/pyproject.toml index a6a6d779..504cad29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,16 @@ dependencies = [ "ruff>=0.11.9", "pre-commit>=4.2.0", "uv>=0.7.11", + "PyPDF2>=3.0.1", + "python-docx>=1.1.0", + "pandas>=2.2.2", + "chardet>=5.2.0", + "markdown>=3.6", + "beautifulsoup4>=4.12.3", + "ebooklib>=0.18", + "html2text>=2024.2.26", + "langchain>=0.2.0", + "chromadb>=0.4.24", ] keywords = [ "bot", diff --git a/templates/default-pipeline-config.json b/templates/default-pipeline-config.json index 796c6356..d06e4661 100644 --- a/templates/default-pipeline-config.json +++ b/templates/default-pipeline-config.json @@ -44,7 +44,8 @@ "role": "system", "content": "You are a helpful assistant." 
} - ] + ], + "knowledge-base": "" }, "dify-service-api": { "base-url": "https://api.dify.ai/v1", diff --git a/templates/metadata/pipeline/ai.yaml b/templates/metadata/pipeline/ai.yaml index 90732dc8..ffbefe63 100644 --- a/templates/metadata/pipeline/ai.yaml +++ b/templates/metadata/pipeline/ai.yaml @@ -68,6 +68,16 @@ stages: zh_Hans: 除非您了解消息结构,否则请只使用 system 单提示词 type: prompt-editor required: true + - name: knowledge-base + label: + en_US: Knowledge Base + zh_Hans: 知识库 + description: + en_US: Configure the knowledge base to use for the agent, if not selected, the agent will directly use the LLM to reply + zh_Hans: 配置用于提升回复质量的知识库,若不选择,则直接使用大模型回复 + type: knowledge-base-selector + required: false + default: '' - name: dify-service-api label: en_US: Dify Service API @@ -298,3 +308,4 @@ stages: type: string required: false default: 'response' + diff --git a/web/package-lock.json b/web/package-lock.json index ee9b5767..fcc17852 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -12,23 +12,27 @@ "@dnd-kit/sortable": "^10.0.0", "@hookform/resolvers": "^5.0.1", "@radix-ui/react-checkbox": "^1.3.1", - "@radix-ui/react-dialog": "^1.1.13", + "@radix-ui/react-dialog": "^1.1.14", "@radix-ui/react-hover-card": "^1.1.13", "@radix-ui/react-label": "^2.1.6", "@radix-ui/react-popover": "^1.1.14", "@radix-ui/react-scroll-area": "^1.2.9", "@radix-ui/react-select": "^2.2.4", - "@radix-ui/react-slot": "^1.2.2", + "@radix-ui/react-separator": "^1.1.7", + "@radix-ui/react-slot": "^1.2.3", "@radix-ui/react-switch": "^1.2.4", "@radix-ui/react-tabs": "^1.1.11", "@radix-ui/react-toggle": "^1.1.8", "@radix-ui/react-toggle-group": "^1.1.9", + "@radix-ui/react-tooltip": "^1.2.7", "@tailwindcss/postcss": "^4.1.5", + "@tanstack/react-table": "^8.21.3", "axios": "^1.8.4", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "i18next": "^25.1.2", "i18next-browser-languagedetector": "^8.1.0", + "input-otp": "^1.4.2", "lodash": "^4.17.21", "lucide-react": "^0.507.0", "next": "15.2.4", 
@@ -1037,6 +1041,24 @@ } } }, + "node_modules/@radix-ui/react-collection/node_modules/@radix-ui/react-slot": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.2.tgz", + "integrity": "sha512-y7TBO4xN4Y94FvcWIOIh18fM4R1A8S4q1jhoz4PNzOoHsFcN8pogcFmZrTYAm4F9VRUrWP/Mw7xSKybIeRI+CQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-compose-refs": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.2.tgz", @@ -1068,22 +1090,22 @@ } }, "node_modules/@radix-ui/react-dialog": { - "version": "1.1.13", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.13.tgz", - "integrity": "sha512-ARFmqUyhIVS3+riWzwGTe7JLjqwqgnODBUZdqpWar/z1WFs9z76fuOs/2BOWCR+YboRn4/WN9aoaGVwqNRr8VA==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.14.tgz", + "integrity": "sha512-+CpweKjqpzTmwRwcYECQcNYbI8V9VSQt0SNFKeEBLgfucbsLssU6Ppq7wUdNXEGb573bMjFhVjKVll8rmV6zMw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.9", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.6", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-portal": "1.1.8", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.2", - "@radix-ui/react-slot": "1.2.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", 
"@radix-ui/react-use-controllable-state": "1.2.2", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" @@ -1103,6 +1125,105 @@ } } }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.10.tgz", + "integrity": "sha512-IM1zzRV4W3HtVgftdQiiOmA0AdJlCtMLe00FXaHwgt3rAnNsIyDqshvkIW3hj/iu5hu8ERP7KIYki6NkqDxAwQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-escape-keydown": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-focus-scope": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz", + "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-portal": { + "version": "1.1.9", 
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", + "integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-direction": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@radix-ui/react-direction/-/react-direction-1.1.1.tgz", @@ -1448,24 +1569,6 @@ } } }, - "node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-slot": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", - "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", - "license": "MIT", - "dependencies": { - "@radix-ui/react-compose-refs": "1.1.2" - }, - "peerDependencies": { 
- "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, "node_modules/@radix-ui/react-popper": { "version": "1.2.6", "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.6.tgz", @@ -1569,6 +1672,24 @@ } } }, + "node_modules/@radix-ui/react-primitive/node_modules/@radix-ui/react-slot": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.2.tgz", + "integrity": "sha512-y7TBO4xN4Y94FvcWIOIh18fM4R1A8S4q1jhoz4PNzOoHsFcN8pogcFmZrTYAm4F9VRUrWP/Mw7xSKybIeRI+CQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-roving-focus": { "version": "1.1.9", "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.9.tgz", @@ -1654,24 +1775,6 @@ } } }, - "node_modules/@radix-ui/react-scroll-area/node_modules/@radix-ui/react-slot": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", - "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", - "license": "MIT", - "dependencies": { - "@radix-ui/react-compose-refs": "1.1.2" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, "node_modules/@radix-ui/react-select": { "version": "2.2.4", "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.4.tgz", @@ -1715,7 +1818,7 @@ } } }, - "node_modules/@radix-ui/react-slot": { + 
"node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-slot": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.2.tgz", "integrity": "sha512-y7TBO4xN4Y94FvcWIOIh18fM4R1A8S4q1jhoz4PNzOoHsFcN8pogcFmZrTYAm4F9VRUrWP/Mw7xSKybIeRI+CQ==", @@ -1733,6 +1836,70 @@ } } }, + "node_modules/@radix-ui/react-separator": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.7.tgz", + "integrity": "sha512-0HEb8R9E8A+jZjvmFCy/J4xhbXy3TV+9XSnGJ3KvTtjlIUy/YQ/p6UYZvi7YbeoeXdyU9+Y3scizK6hkY37baA==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-separator/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slot": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", + "integrity": 
"sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-switch": { "version": "1.2.4", "resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.2.4.tgz", @@ -1846,6 +2013,192 @@ } } }, + "node_modules/@radix-ui/react-tooltip": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.7.tgz", + "integrity": "sha512-Ap+fNYwKTYJ9pzqW+Xe2HtMRbQ/EeWkj2qykZ6SuEV4iS/o1bZI5ssJbk4D2r8XuDuOBVz/tIx2JObtuqU+5Zw==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-presence": "1.1.4", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-visually-hidden": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-arrow": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.7.tgz", + "integrity": "sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==", 
+ "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.10.tgz", + "integrity": "sha512-IM1zzRV4W3HtVgftdQiiOmA0AdJlCtMLe00FXaHwgt3rAnNsIyDqshvkIW3hj/iu5hu8ERP7KIYki6NkqDxAwQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-escape-keydown": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-popper": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.7.tgz", + "integrity": "sha512-IUFAccz1JyKcf/RjB552PlWwxjeCJB8/4KxT7EhBHOJM+mN7LdW+B3kacJXILm32xawcMMjb2i0cIZpo+f9kiQ==", + "license": "MIT", + "dependencies": { + "@floating-ui/react-dom": "^2.0.0", + "@radix-ui/react-arrow": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-layout-effect": "1.1.1", + 
"@radix-ui/react-use-rect": "1.1.1", + "@radix-ui/react-use-size": "1.1.1", + "@radix-ui/rect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-portal": { + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", + "integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + 
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-visually-hidden": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.3.tgz", + "integrity": "sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-use-callback-ref": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.1.1.tgz", @@ -2295,6 +2648,39 @@ "tailwindcss": "4.1.5" } }, + "node_modules/@tanstack/react-table": { + "version": "8.21.3", + "resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.21.3.tgz", + "integrity": "sha512-5nNMTSETP4ykGegmVkhjcS8tTLW6Vl4axfEGQN3v0zdHYbK4UfoqfPChclTrJ4EoK9QynqAu9oUf8VEmrpZ5Ww==", + "license": "MIT", + "dependencies": { + "@tanstack/table-core": "8.21.3" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + }, + "peerDependencies": { + "react": ">=16.8", + "react-dom": ">=16.8" + } + }, + "node_modules/@tanstack/table-core": { + "version": "8.21.3", + "resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.21.3.tgz", + "integrity": "sha512-ldZXEhOBb8Is7xLs01fR3YEc3DERiz5silj8tnGkFZytt1abEvl/GhUmCE0PMLaMPTa3Jk4HbKmRlHmu+gCftg==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/tannerlinsley" + } + }, "node_modules/@tybys/wasm-util": { 
"version": "0.9.0", "resolved": "https://registry.npmjs.org/@tybys/wasm-util/-/wasm-util-0.9.0.tgz", @@ -4763,6 +5149,16 @@ "node": ">=0.8.19" } }, + "node_modules/input-otp": { + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/input-otp/-/input-otp-1.4.2.tgz", + "integrity": "sha512-l3jWwYNvrEa6NTCt7BECfCm48GvwuZzkoeG3gBL2w4CHeOXW3eKFmf9UNYkNfYc3mxMrthMnxjIE07MT0zLBQA==", + "license": "MIT", + "peerDependencies": { + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0.0 || ^19.0.0-rc" + } + }, "node_modules/internal-slot": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.1.0.tgz", diff --git a/web/package.json b/web/package.json index 7a522b1c..1089bf75 100644 --- a/web/package.json +++ b/web/package.json @@ -23,6 +23,7 @@ "@hookform/resolvers": "^5.0.1", "@radix-ui/react-checkbox": "^1.3.1", "@radix-ui/react-dialog": "^1.1.14", + "@radix-ui/react-dropdown-menu": "^2.1.15", "@radix-ui/react-hover-card": "^1.1.13", "@radix-ui/react-label": "^2.1.6", "@radix-ui/react-popover": "^1.1.14", @@ -36,6 +37,7 @@ "@radix-ui/react-toggle-group": "^1.1.9", "@radix-ui/react-tooltip": "^1.2.7", "@tailwindcss/postcss": "^4.1.5", + "@tanstack/react-table": "^8.21.3", "axios": "^1.8.4", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", diff --git a/web/src/app/home/bots/BotDetailDialog.tsx b/web/src/app/home/bots/BotDetailDialog.tsx index 1c4a2403..db19e1d4 100644 --- a/web/src/app/home/bots/BotDetailDialog.tsx +++ b/web/src/app/home/bots/BotDetailDialog.tsx @@ -127,10 +127,8 @@ export default function BotDetailDialog({ @@ -199,10 +197,8 @@ export default function BotDetailDialog({ )} {activeMenu === 'logs' && botId && ( diff --git a/web/src/app/home/bots/components/bot-form/BotForm.tsx b/web/src/app/home/bots/components/bot-form/BotForm.tsx index 40a902c2..bd757ae0 100644 --- a/web/src/app/home/bots/components/bot-form/BotForm.tsx +++ 
b/web/src/app/home/bots/components/bot-form/BotForm.tsx @@ -64,17 +64,13 @@ const getFormSchema = (t: (key: string) => string) => export default function BotForm({ initBotId, onFormSubmit, - onFormCancel, onBotDeleted, onNewBotCreated, - hideButtons = false, }: { initBotId?: string; onFormSubmit: (value: z.infer>) => void; - onFormCancel: () => void; onBotDeleted: () => void; onNewBotCreated: (botId: string) => void; - hideButtons?: boolean; }) { const { t } = useTranslation(); const formSchema = getFormSchema(t); @@ -214,6 +210,7 @@ export default function BotForm({ }); setAdapterNameToDynamicConfigMap(adapterNameToDynamicConfigMap); } + async function getBotConfig( botId: string, ): Promise> { @@ -527,45 +524,6 @@ export default function BotForm({ )} - - {!hideButtons && ( -
-
- {!initBotId && ( - - )} - {initBotId && ( - <> - - - - )} - -
-
- )} diff --git a/web/src/app/home/bots/page.tsx b/web/src/app/home/bots/page.tsx index d4305898..ad130fae 100644 --- a/web/src/app/home/bots/page.tsx +++ b/web/src/app/home/bots/page.tsx @@ -92,7 +92,7 @@ export default function BotConfigPage() { } return ( -
+
; }) { const [llmModels, setLlmModels] = useState([]); + const [knowledgeBases, setKnowledgeBases] = useState([]); const { t } = useTranslation(); useEffect(() => { @@ -50,6 +52,19 @@ export default function DynamicFormItemComponent({ } }, [config.type]); + useEffect(() => { + if (config.type === DynamicFormItemType.KNOWLEDGE_BASE_SELECTOR) { + httpClient + .getKnowledgeBases() + .then((resp) => { + setKnowledgeBases(resp.bases); + }) + .catch((err) => { + toast.error('获取知识库列表失败:' + err.message); + }); + } + }, [config.type]); + switch (config.type) { case DynamicFormItemType.INT: case DynamicFormItemType.FLOAT: @@ -249,6 +264,25 @@ export default function DynamicFormItemComponent({ ); + case DynamicFormItemType.KNOWLEDGE_BASE_SELECTOR: + return ( + + ); + case DynamicFormItemType.PROMPT_EDITOR: return (
diff --git a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx index e21317d6..b3edb98a 100644 --- a/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx +++ b/web/src/app/home/components/home-sidebar/sidbarConfigList.tsx @@ -47,6 +47,7 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/deploy/models/readme.html', }, }), + new SidebarChildVO({ id: 'pipelines', name: t('pipelines.title'), @@ -67,6 +68,25 @@ export const sidebarConfigList = [ zh_Hans: 'https://docs.langbot.app/zh/deploy/pipelines/readme.html', }, }), + new SidebarChildVO({ + id: 'knowledge', + name: t('knowledge.title'), + icon: ( + + + + ), + route: '/home/knowledge', + description: t('knowledge.description'), + helpLink: { + en_US: 'https://docs.langbot.app/en/deploy/knowledge/readme.html', + zh_Hans: 'https://docs.langbot.app/zh/deploy/knowledge/readme.html', + }, + }), new SidebarChildVO({ id: 'plugins', name: t('plugins.title'), diff --git a/web/src/app/home/knowledge/KBDetailDialog.tsx b/web/src/app/home/knowledge/KBDetailDialog.tsx new file mode 100644 index 00000000..262d872f --- /dev/null +++ b/web/src/app/home/knowledge/KBDetailDialog.tsx @@ -0,0 +1,236 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogFooter, +} from '@/components/ui/dialog'; +import { + Sidebar, + SidebarContent, + SidebarGroup, + SidebarGroupContent, + SidebarMenu, + SidebarMenuButton, + SidebarMenuItem, + SidebarProvider, +} from '@/components/ui/sidebar'; +import { Button } from '@/components/ui/button'; +import { useTranslation } from 'react-i18next'; +import { httpClient } from '@/app/infra/http/HttpClient'; +// import { KnowledgeBase } from '@/app/infra/entities/api'; +import KBForm from '@/app/home/knowledge/components/kb-form/KBForm'; +import KBDoc from '@/app/home/knowledge/components/kb-docs/KBDoc'; + +interface 
KBDetailDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + kbId?: string; + onFormCancel: () => void; + onKbDeleted: () => void; + onNewKbCreated: (kbId: string) => void; + onKbUpdated: (kbId: string) => void; +} + +export default function KBDetailDialog({ + open, + onOpenChange, + kbId: propKbId, + onFormCancel, + onKbDeleted, + onNewKbCreated, + onKbUpdated, +}: KBDetailDialogProps) { + const { t } = useTranslation(); + const [kbId, setKbId] = useState(propKbId); + const [activeMenu, setActiveMenu] = useState('metadata'); + const [showDeleteConfirm, setShowDeleteConfirm] = useState(false); + + useEffect(() => { + setKbId(propKbId); + setActiveMenu('metadata'); + }, [propKbId, open]); + + const menu = [ + { + key: 'metadata', + label: t('knowledge.metadata'), + icon: ( + + + + ), + }, + { + key: 'documents', + label: t('knowledge.documents'), + icon: ( + + + + ), + }, + ]; + + const confirmDelete = () => { + httpClient.deleteKnowledgeBase(kbId ?? '').then(() => { + onKbDeleted(); + }); + setShowDeleteConfirm(false); + }; + + if (!kbId) { + // new kb + return ( + + +
+ + {t('knowledge.createKnowledgeBase')} + +
+ {activeMenu === 'metadata' && ( + + )} + {activeMenu === 'documents' &&
documents
} +
+ {activeMenu === 'metadata' && ( + +
+ + +
+
+ )} +
+
+
+ ); + } + + return ( + <> + + + + + + + + + {menu.map((item) => ( + + setActiveMenu(item.key)} + > + + {item.icon} + {item.label} + + + + ))} + + + + + +
+ + + {activeMenu === 'metadata' + ? t('knowledge.editKnowledgeBase') + : t('knowledge.editDocument')} + + +
+ {activeMenu === 'metadata' && ( + + )} + {activeMenu === 'documents' && } +
+ {activeMenu === 'metadata' && ( + +
+ + + +
+
+ )} +
+
+
+
+ + {/* 删除确认对话框 */} + + + + {t('common.confirmDelete')} + +
+ {t('knowledge.deleteKnowledgeBaseConfirmation')} +
+ + + + +
+
+ + ); +} diff --git a/web/src/app/home/knowledge/components/kb-card/KBCard.module.css b/web/src/app/home/knowledge/components/kb-card/KBCard.module.css new file mode 100644 index 00000000..2ecbd44a --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-card/KBCard.module.css @@ -0,0 +1,107 @@ +.cardContainer { + width: 100%; + height: 10rem; + background-color: #fff; + border-radius: 10px; + box-shadow: 0px 2px 2px 0 rgba(0, 0, 0, 0.2); + padding: 1.2rem; + cursor: pointer; + display: flex; + flex-direction: row; + justify-content: space-between; + gap: 0.5rem; +} + +.cardContainer:hover { + box-shadow: 0px 2px 8px 0 rgba(0, 0, 0, 0.1); +} + +.basicInfoContainer { + width: 100%; + height: 100%; + display: flex; + flex-direction: column; + justify-content: space-between; + gap: 0.4rem; + min-width: 0; +} + +.basicInfoNameContainer { + display: flex; + flex-direction: column; + gap: 0.2rem; +} + +.basicInfoNameText { + font-size: 1.4rem; + font-weight: 500; +} + +.basicInfoDescriptionText { + font-size: 0.9rem; + font-weight: 400; + display: -webkit-box; + -webkit-line-clamp: 3; + -webkit-box-orient: vertical; + overflow: hidden; + text-overflow: ellipsis; + color: #b1b1b1; +} + +.basicInfoLastUpdatedTimeContainer { + display: flex; + flex-direction: row; + align-items: center; + gap: 0.5rem; +} + +.basicInfoUpdateTimeIcon { + width: 1.2rem; + height: 1.2rem; +} + +.basicInfoUpdateTimeText { + font-size: 1rem; + font-weight: 400; +} + +.operationContainer { + display: flex; + flex-direction: column; + align-items: flex-end; + justify-content: space-between; + gap: 0.5rem; + width: 8rem; +} + +.operationDefaultBadge { + display: flex; + flex-direction: row; + gap: 0.5rem; +} + +.operationDefaultBadgeIcon { + width: 1.2rem; + height: 1.2rem; + color: #ffcd27; +} + +.operationDefaultBadgeText { + font-size: 1rem; + font-weight: 400; + color: #ffcd27; +} + +.bigText { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + font-size: 1.4rem; + 
font-weight: bold; + max-width: 100%; +} + +.debugButtonIcon { + width: 1.2rem; + height: 1.2rem; +} diff --git a/web/src/app/home/knowledge/components/kb-card/KBCard.tsx b/web/src/app/home/knowledge/components/kb-card/KBCard.tsx new file mode 100644 index 00000000..560b0497 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-card/KBCard.tsx @@ -0,0 +1,36 @@ +import { KnowledgeBaseVO } from '@/app/home/knowledge/components/kb-card/KBCardVO'; +import { useTranslation } from 'react-i18next'; +import styles from './KBCard.module.css'; + +export default function KBCard({ kbCardVO }: { kbCardVO: KnowledgeBaseVO }) { + const { t } = useTranslation(); + return ( +
+
+
+
+ {kbCardVO.name} +
+
+ {kbCardVO.description} +
+
+ +
+ + + +
+ {t('knowledge.updateTime')} + {kbCardVO.lastUpdatedTimeAgo} +
+
+
+
+ ); +} diff --git a/web/src/app/home/knowledge/components/kb-card/KBCardVO.ts b/web/src/app/home/knowledge/components/kb-card/KBCardVO.ts new file mode 100644 index 00000000..bfbc2adb --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-card/KBCardVO.ts @@ -0,0 +1,23 @@ +export interface IKnowledgeBaseVO { + id: string; + name: string; + description: string; + embeddingModelUUID: string; + lastUpdatedTimeAgo: string; +} + +export class KnowledgeBaseVO implements IKnowledgeBaseVO { + id: string; + name: string; + description: string; + embeddingModelUUID: string; + lastUpdatedTimeAgo: string; + + constructor(props: IKnowledgeBaseVO) { + this.id = props.id; + this.name = props.name; + this.description = props.description; + this.embeddingModelUUID = props.embeddingModelUUID; + this.lastUpdatedTimeAgo = props.lastUpdatedTimeAgo; + } +} diff --git a/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx b/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx new file mode 100644 index 00000000..3b4123ec --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-docs/FileUploadZone.tsx @@ -0,0 +1,145 @@ +import React, { useCallback, useState } from 'react'; +import { Card, CardContent } from '@/components/ui/card'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { toast } from 'sonner'; +import { useTranslation } from 'react-i18next'; + +interface FileUploadZoneProps { + kbId: string; + onUploadSuccess: () => void; + onUploadError: (error: string) => void; +} + +export default function FileUploadZone({ + kbId, + onUploadSuccess, + onUploadError, +}: FileUploadZoneProps) { + const { t } = useTranslation(); + const [isDragOver, setIsDragOver] = useState(false); + const [isUploading, setIsUploading] = useState(false); + + const handleUpload = useCallback( + async (file: File) => { + if (isUploading) return; + + setIsUploading(true); + const toastId = toast.loading(t('knowledge.documentsTab.uploadingFile')); + + try { + // 
Step 1: Upload file to server + const uploadResult = await httpClient.uploadDocumentFile(file); + + // Step 2: Associate file with knowledge base + await httpClient.uploadKnowledgeBaseFile(kbId, uploadResult.file_id); + + toast.success(t('knowledge.documentsTab.uploadSuccess'), { + id: toastId, + }); + onUploadSuccess(); + } catch (error) { + console.error('File upload failed:', error); + const errorMessage = t('knowledge.documentsTab.uploadError'); + toast.error(errorMessage, { id: toastId }); + onUploadError(errorMessage); + } finally { + setIsUploading(false); + } + }, + [kbId, isUploading, onUploadSuccess, onUploadError], + ); + + const handleDragOver = useCallback((e: React.DragEvent) => { + e.preventDefault(); + setIsDragOver(true); + }, []); + + const handleDragLeave = useCallback((e: React.DragEvent) => { + e.preventDefault(); + setIsDragOver(false); + }, []); + + const handleDrop = useCallback( + (e: React.DragEvent) => { + e.preventDefault(); + setIsDragOver(false); + + const files = Array.from(e.dataTransfer.files); + if (files.length > 0) { + handleUpload(files[0]); + } + }, + [handleUpload], + ); + + const handleFileSelect = useCallback( + (e: React.ChangeEvent) => { + const files = e.target.files; + if (files && files.length > 0) { + handleUpload(files[0]); + } + }, + [handleUpload], + ); + + return ( + + +
+ + + +
+
+
+ ); +} diff --git a/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx new file mode 100644 index 00000000..fb94dace --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-docs/KBDoc.tsx @@ -0,0 +1,72 @@ +import { useEffect, useState } from 'react'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { KnowledgeBaseFile } from '@/app/infra/entities/api'; +import { columns, DocumentFile } from './documents/columns'; +import { DataTable } from './documents/data-table'; +import FileUploadZone from './FileUploadZone'; +import { toast } from 'sonner'; +import { useTranslation } from 'react-i18next'; + +export default function KBDoc({ kbId }: { kbId: string }) { + const [documentsList, setDocumentsList] = useState([]); + const { t } = useTranslation(); + + useEffect(() => { + getDocumentsList(); + + const intervalId = setInterval(() => { + getDocumentsList(); + }, 5000); + + return () => { + clearInterval(intervalId); + }; + }, [kbId]); + + async function getDocumentsList() { + const resp = await httpClient.getKnowledgeBaseFiles(kbId); + setDocumentsList( + resp.files.map((file: KnowledgeBaseFile) => { + return { + uuid: file.uuid, + name: file.file_name, + status: file.status, + }; + }), + ); + } + + const handleUploadSuccess = () => { + // Refresh document list after successful upload + getDocumentsList(); + }; + + const handleUploadError = (error: string) => { + // Error messages are already handled by toast in FileUploadZone component + console.error('Upload failed:', error); + }; + + const handleDelete = (id: string) => { + httpClient + .deleteKnowledgeBaseFile(kbId, id) + .then(() => { + getDocumentsList(); + toast.success(t('knowledge.documentsTab.fileDeleteSuccess')); + }) + .catch((error) => { + console.error('Delete failed:', error); + toast.error(t('knowledge.documentsTab.fileDeleteFailed')); + }); + }; + + return ( +
+ + +
+ ); +} diff --git a/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx b/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx new file mode 100644 index 00000000..6142cfc4 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-docs/documents/columns.tsx @@ -0,0 +1,94 @@ +'use client'; + +import { ColumnDef } from '@tanstack/react-table'; +import { MoreHorizontal } from 'lucide-react'; +import { Button } from '@/components/ui/button'; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu'; +import { Badge } from '@/components/ui/badge'; +import { TFunction } from 'i18next'; + +export type DocumentFile = { + uuid: string; + name: string; + status: string; +}; + +export const columns = ( + onDelete: (id: string) => void, + t: TFunction, +): ColumnDef[] => { + return [ + { + accessorKey: 'name', + header: t('knowledge.documentsTab.name'), + }, + { + accessorKey: 'status', + header: t('knowledge.documentsTab.status'), + cell: ({ row }) => { + const document = row.original; + + switch (document.status) { + case 'processing': + return ( + + {t('knowledge.documentsTab.processing')} + + ); + case 'completed': + return ( + + {t('knowledge.documentsTab.completed')} + + ); + case 'failed': + return ( + + {t('knowledge.documentsTab.failed')} + + ); + default: + return ( + + {document.status} + + ); + } + }, + }, + { + id: 'actions', + cell: ({ row }) => { + const document = row.original; + + return ( + + + + + + + {t('knowledge.documentsTab.actions')} + + + onDelete(document.uuid)}> + {t('knowledge.documentsTab.delete')} + + + + ); + }, + }, + ]; +}; diff --git a/web/src/app/home/knowledge/components/kb-docs/documents/data-table.tsx b/web/src/app/home/knowledge/components/kb-docs/documents/data-table.tsx new file mode 100644 index 00000000..178ccad9 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-docs/documents/data-table.tsx @@ 
-0,0 +1,81 @@ +'use client'; + +import { + ColumnDef, + flexRender, + getCoreRowModel, + useReactTable, +} from '@tanstack/react-table'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table'; +import { useTranslation } from 'react-i18next'; + +interface DataTableProps { + columns: ColumnDef[]; + data: TData[]; +} + +export function DataTable({ + columns, + data, +}: DataTableProps) { + const { t } = useTranslation(); + const table = useReactTable({ + data, + columns, + getCoreRowModel: getCoreRowModel(), + }); + + return ( +
+ + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => { + return ( + + {header.isPlaceholder + ? null + : flexRender( + header.column.columnDef.header, + header.getContext(), + )} + + ); + })} + + ))} + + + {table.getRowModel().rows?.length ? ( + table.getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ))} + + )) + ) : ( + + + {t('knowledge.documentsTab.noResults')} + + + )} + +
+
+ ); +} diff --git a/web/src/app/home/knowledge/components/kb-form/ChooseEntity.ts b/web/src/app/home/knowledge/components/kb-form/ChooseEntity.ts new file mode 100644 index 00000000..54f983e4 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-form/ChooseEntity.ts @@ -0,0 +1,4 @@ +export interface IEmbeddingModelEntity { + label: string; + value: string; +} diff --git a/web/src/app/home/knowledge/components/kb-form/KBForm.tsx b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx new file mode 100644 index 00000000..54d5d6e4 --- /dev/null +++ b/web/src/app/home/knowledge/components/kb-form/KBForm.tsx @@ -0,0 +1,234 @@ +import { useEffect, useState } from 'react'; +import { useForm } from 'react-hook-form'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { z } from 'zod'; +import { useTranslation } from 'react-i18next'; +import { Input } from '@/components/ui/input'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, + FormDescription, +} from '@/components/ui/form'; +import { IEmbeddingModelEntity } from './ChooseEntity'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { KnowledgeBase } from '@/app/infra/entities/api'; +import { toast } from 'sonner'; + +const getFormSchema = (t: (key: string) => string) => + z.object({ + name: z.string().min(1, { message: t('knowledge.kbNameRequired') }), + description: z + .string() + .min(1, { message: t('knowledge.kbDescriptionRequired') }), + embeddingModelUUID: z + .string() + .min(1, { message: t('knowledge.embeddingModelUUIDRequired') }), + }); + +export default function KBForm({ + initKbId, + onNewKbCreated, + onKbUpdated, +}: { + initKbId?: string; + onNewKbCreated: (kbId: string) => void; + onKbUpdated: (kbId: string) => void; +}) { + const { t } = useTranslation(); + const formSchema = getFormSchema(t); + + 
const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + name: '', + description: t('knowledge.defaultDescription'), + embeddingModelUUID: '', + }, + }); + + const [embeddingModelNameList, setEmbeddingModelNameList] = useState< + IEmbeddingModelEntity[] + >([]); + + useEffect(() => { + getEmbeddingModelNameList().then(() => { + if (initKbId) { + getKbConfig(initKbId).then((val) => { + form.setValue('name', val.name); + form.setValue('description', val.description); + form.setValue('embeddingModelUUID', val.embeddingModelUUID); + }); + } + }); + }, []); + + const getKbConfig = async ( + kbId: string, + ): Promise> => { + return new Promise((resolve) => { + httpClient.getKnowledgeBase(kbId).then((res) => { + resolve({ + name: res.base.name, + description: res.base.description, + embeddingModelUUID: res.base.embedding_model_uuid, + }); + }); + }); + }; + + const getEmbeddingModelNameList = async () => { + const resp = await httpClient.getProviderEmbeddingModels(); + setEmbeddingModelNameList( + resp.models.map((item) => { + return { + label: item.name, + value: item.uuid, + }; + }), + ); + }; + + const onSubmit = (data: z.infer) => { + console.log('data', data); + + if (initKbId) { + // update knowledge base + const updateKb: KnowledgeBase = { + name: data.name, + description: data.description, + embedding_model_uuid: data.embeddingModelUUID, + }; + httpClient + .updateKnowledgeBase(initKbId, updateKb) + .then((res) => { + console.log('update knowledge base success', res); + onKbUpdated(res.uuid); + toast.success(t('knowledge.updateKnowledgeBaseSuccess')); + }) + .catch((err) => { + console.error('update knowledge base failed', err); + toast.error(t('knowledge.updateKnowledgeBaseFailed')); + }); + } else { + // create knowledge base + const newKb: KnowledgeBase = { + name: data.name, + description: data.description, + embedding_model_uuid: data.embeddingModelUUID, + }; + httpClient + .createKnowledgeBase(newKb) + .then((res) => { + 
console.log('create knowledge base success', res); + onNewKbCreated(res.uuid); + }) + .catch((err) => { + console.error('create knowledge base failed', err); + }); + } + }; + + return ( + <> +
+ +
+ ( + + + {t('knowledge.kbName')} + * + + + + + + + )} + /> + ( + + + {t('knowledge.kbDescription')} + * + + + + + + + )} + /> + ( + + + {t('knowledge.embeddingModelUUID')} + * + + +
+ +
+
+ + {initKbId + ? t('knowledge.cannotChangeEmbeddingModel') + : t('knowledge.embeddingModelDescription')} + + +
+ )} + /> +
+
+ + + ); +} diff --git a/web/src/app/home/knowledge/knowledgeBase.module.css b/web/src/app/home/knowledge/knowledgeBase.module.css new file mode 100644 index 00000000..e811b521 --- /dev/null +++ b/web/src/app/home/knowledge/knowledgeBase.module.css @@ -0,0 +1,15 @@ +.configPageContainer { + width: 100%; + height: 100%; +} + +.knowledgeListContainer { + width: 100%; + padding-left: 0.8rem; + padding-right: 0.8rem; + display: grid; + grid-template-columns: repeat(auto-fill, minmax(24rem, 1fr)); + gap: 2rem; + justify-items: stretch; + align-items: start; +} diff --git a/web/src/app/home/knowledge/page.tsx b/web/src/app/home/knowledge/page.tsx new file mode 100644 index 00000000..0a8cc2eb --- /dev/null +++ b/web/src/app/home/knowledge/page.tsx @@ -0,0 +1,114 @@ +'use client'; + +import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; +import styles from './knowledgeBase.module.css'; +import { useTranslation } from 'react-i18next'; +import { useEffect, useState } from 'react'; +import { KnowledgeBaseVO } from '@/app/home/knowledge/components/kb-card/KBCardVO'; +import KBCard from '@/app/home/knowledge/components/kb-card/KBCard'; +import KBDetailDialog from '@/app/home/knowledge/KBDetailDialog'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { KnowledgeBase } from '@/app/infra/entities/api'; + +export default function KnowledgePage() { + const { t } = useTranslation(); + const [knowledgeBaseList, setKnowledgeBaseList] = useState( + [], + ); + const [selectedKbId, setSelectedKbId] = useState(''); + const [detailDialogOpen, setDetailDialogOpen] = useState(false); + + useEffect(() => { + getKnowledgeBaseList(); + }, []); + + async function getKnowledgeBaseList() { + const resp = await httpClient.getKnowledgeBases(); + setKnowledgeBaseList( + resp.bases.map((kb: KnowledgeBase) => { + const currentTime = new Date(); + const lastUpdatedTimeAgo = Math.floor( + (currentTime.getTime() - + new Date(kb.updated_at 
?? currentTime.getTime()).getTime()) / + 1000 / + 60 / + 60 / + 24, + ); + + const lastUpdatedTimeAgoText = + lastUpdatedTimeAgo > 0 + ? ` ${lastUpdatedTimeAgo} ${t('knowledge.daysAgo')}` + : t('knowledge.today'); + + return new KnowledgeBaseVO({ + id: kb.uuid || '', + name: kb.name, + description: kb.description, + embeddingModelUUID: kb.embedding_model_uuid, + lastUpdatedTimeAgo: lastUpdatedTimeAgoText, + }); + }), + ); + } + + const handleKBCardClick = (kbId: string) => { + setSelectedKbId(kbId); + setDetailDialogOpen(true); + }; + + const handleCreateKBClick = () => { + setSelectedKbId(''); + setDetailDialogOpen(true); + }; + + const handleFormCancel = () => { + setDetailDialogOpen(false); + }; + + const handleKbDeleted = () => { + getKnowledgeBaseList(); + setDetailDialogOpen(false); + }; + + const handleNewKbCreated = (newKbId: string) => { + getKnowledgeBaseList(); + setSelectedKbId(newKbId); + setDetailDialogOpen(true); + }; + + const handleKbUpdated = () => { + getKnowledgeBaseList(); + }; + + return ( +
+ + +
+ + + {knowledgeBaseList.map((kb) => { + return ( +
handleKBCardClick(kb.id)}> + +
+ ); + })} +
+
+ ); +} diff --git a/web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts b/web/src/app/home/models/component/ChooseRequesterEntity.ts similarity index 100% rename from web/src/app/home/models/component/llm-form/ChooseRequesterEntity.ts rename to web/src/app/home/models/component/ChooseRequesterEntity.ts diff --git a/web/src/app/home/models/component/ICreateEmbeddingField.ts b/web/src/app/home/models/component/ICreateEmbeddingField.ts new file mode 100644 index 00000000..ea198f3f --- /dev/null +++ b/web/src/app/home/models/component/ICreateEmbeddingField.ts @@ -0,0 +1,7 @@ +export interface ICreateEmbeddingField { + name: string; + model_provider: string; + url: string; + api_key: string; + extra_args?: string[]; +} diff --git a/web/src/app/home/models/ICreateLLMField.ts b/web/src/app/home/models/component/ICreateLLMField.ts similarity index 100% rename from web/src/app/home/models/ICreateLLMField.ts rename to web/src/app/home/models/component/ICreateLLMField.ts diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css new file mode 100644 index 00000000..9c6c54f7 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.module.css @@ -0,0 +1,97 @@ +.cardContainer { + width: 100%; + height: 10rem; + background-color: #fff; + border-radius: 10px; + box-shadow: 0px 2px 2px 0 rgba(0, 0, 0, 0.2); + padding: 1.2rem; + cursor: pointer; +} + +.cardContainer:hover { + box-shadow: 0px 2px 8px 0 rgba(0, 0, 0, 0.1); +} + +.iconBasicInfoContainer { + width: 100%; + height: 100%; + display: flex; + flex-direction: row; + gap: 0.8rem; + user-select: none; +} + +.iconImage { + width: 3.8rem; + height: 3.8rem; + margin: 0.2rem; + border-radius: 50%; +} + +.basicInfoContainer { + display: flex; + flex-direction: column; + gap: 0.2rem; + min-width: 0; + width: 100%; +} + +.basicInfoText { + font-size: 1.4rem; + font-weight: bold; +} + 
+.providerContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; +} + +.providerIcon { + width: 1.2rem; + height: 1.2rem; + margin-top: 0.2rem; + color: #626262; +} + +.providerLabel { + font-size: 1.2rem; + font-weight: 600; + color: #626262; +} + +.baseURLContainer { + display: flex; + flex-direction: row; + align-items: center; + justify-content: flex-start; + gap: 0.2rem; + width: calc(100% - 3rem); +} + +.baseURLIcon { + width: 1.2rem; + height: 1.2rem; + color: #626262; +} + +.baseURLText { + font-size: 1rem; + width: 100%; + color: #626262; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + max-width: 100%; +} + +.bigText { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + font-size: 1.4rem; + font-weight: bold; + max-width: 100%; +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx new file mode 100644 index 00000000..e3dfaf80 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCard.tsx @@ -0,0 +1,53 @@ +import styles from './EmbeddingCard.module.css'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; + +export default function EmbeddingCard({ cardVO }: { cardVO: EmbeddingCardVO }) { + return ( +
+
+ icon + +
+ {/* 名称 */} +
+ {cardVO.name} +
+ {/* 厂商 */} +
+ + + + + {cardVO.providerLabel} + +
+ {/* baseURL */} +
+ + + + {cardVO.baseURL} +
+
+
+
+ ); +} diff --git a/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts new file mode 100644 index 00000000..f6d960f6 --- /dev/null +++ b/web/src/app/home/models/component/embedding-card/EmbeddingCardVO.ts @@ -0,0 +1,23 @@ +export interface IEmbeddingCardVO { + id: string; + iconURL: string; + name: string; + providerLabel: string; + baseURL: string; +} + +export class EmbeddingCardVO implements IEmbeddingCardVO { + id: string; + iconURL: string; + providerLabel: string; + name: string; + baseURL: string; + + constructor(props: IEmbeddingCardVO) { + this.id = props.id; + this.iconURL = props.iconURL; + this.providerLabel = props.providerLabel; + this.name = props.name; + this.baseURL = props.baseURL; + } +} diff --git a/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx new file mode 100644 index 00000000..4658a22f --- /dev/null +++ b/web/src/app/home/models/component/embedding-form/EmbeddingForm.tsx @@ -0,0 +1,563 @@ +import { ICreateEmbeddingField } from '@/app/home/models/component/ICreateEmbeddingField'; +import { useEffect, useState } from 'react'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { EmbeddingModel } from '@/app/infra/entities/api'; +import { UUID } from 'uuidjs'; + +import { zodResolver } from '@hookform/resolvers/zod'; +import { useForm } from 'react-hook-form'; +import { z } from 'zod'; +import { useTranslation } from 'react-i18next'; + +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogFooter, +} from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import 
{ Input } from '@/components/ui/input';
+import {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectTrigger,
+  SelectValue,
+} from '@/components/ui/select';
+import { toast } from 'sonner';
+import { i18nObj } from '@/i18n/I18nProvider';
+
+// Builds the zod schema for one "extra argument" row.
+// `t` is the translation function used for the validation messages.
+// The value is always kept as a string; superRefine only checks that the
+// string is consistent with the declared type.
+const getExtraArgSchema = (t: (key: string) => string) =>
+  z
+    .object({
+      key: z.string().min(1, { message: t('models.keyNameRequired') }),
+      type: z.enum(['string', 'number', 'boolean']),
+      value: z.string(),
+    })
+    .superRefine((data, ctx) => {
+      // A 'number' row must hold text that Number() can parse.
+      if (data.type === 'number' && isNaN(Number(data.value))) {
+        ctx.addIssue({
+          code: z.ZodIssueCode.custom,
+          message: t('models.mustBeValidNumber'),
+          path: ['value'],
+        });
+      }
+      // A 'boolean' row must be exactly 'true' or 'false'.
+      if (
+        data.type === 'boolean' &&
+        data.value !== 'true' &&
+        data.value !== 'false'
+      ) {
+        ctx.addIssue({
+          code: z.ZodIssueCode.custom,
+          message: t('models.mustBeTrueOrFalse'),
+          path: ['value'],
+        });
+      }
+    });
+
+// Builds the zod schema for the whole form: four required non-empty strings
+// plus an optional list of extra-argument rows.
+const getFormSchema = (t: (key: string) => string) =>
+  z.object({
+    name: z.string().min(1, { message: t('models.modelNameRequired') }),
+    model_provider: z
+      .string()
+      .min(1, { message: t('models.modelProviderRequired') }),
+    url: z.string().min(1, { message: t('models.requestURLRequired') }),
+    api_key: z.string().min(1, { message: t('models.apiKeyRequired') }),
+    extra_args: z.array(getExtraArgSchema(t)).optional(),
+  });
+
+// Create/edit form for an embedding model.
+// - editMode / initEmbeddingId: when both are set, the existing model is
+//   loaded into the form on mount.
+// - onFormSubmit / onFormCancel / onEmbeddingDeleted: callbacks supplied by
+//   the parent; they take no arguments.
+export default function EmbeddingForm({
+  editMode,
+  initEmbeddingId,
+  onFormSubmit,
+  onFormCancel,
+  onEmbeddingDeleted,
+}: {
+  editMode: boolean;
+  initEmbeddingId?: string;
+  onFormSubmit: () => void;
+  onFormCancel: () => void;
+  onEmbeddingDeleted: () => void;
+}) {
+  const { t } = useTranslation();
+  const formSchema = getFormSchema(t);
+
+  // NOTE(review): the type argument of useForm looks like it was lost when
+  // this patch text was extracted (a stray '>' remains) - confirm against
+  // the real file.
+  const form = useForm>({
+    resolver: zodResolver(formSchema),
+    defaultValues: {
+      name: '',
+      model_provider: '',
+      url: '',
+      api_key: 'sk-xxxxx',
+      extra_args: [],
+    },
+  });
+
+  // Extra-argument rows are also kept in component state, separately from
+  // the form value.
+  const [extraArgs, setExtraArgs] = useState<
+    { key: string; type: 'string' | 'number' | 'boolean'; value: string }[]
+
>([]);
+
+  const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false);
+  const [requesterNameList, setRequesterNameList] = useState<
+    IChooseRequesterEntity[]
+  >([]);
+  const [requesterDefaultURLList, setRequesterDefaultURLList] = useState<
+    string[]
+  >([]);
+  const [modelTesting, setModelTesting] = useState(false);
+
+  // On mount: load the requester list, then either fill the form from the
+  // existing model (edit mode) or reset it to the defaults (create mode).
+  useEffect(() => {
+    initEmbeddingModelFormComponent().then(() => {
+      if (editMode && initEmbeddingId) {
+        getEmbeddingConfig(initEmbeddingId).then((val) => {
+          form.setValue('name', val.name);
+          form.setValue('model_provider', val.model_provider);
+          form.setValue('url', val.url);
+          form.setValue('api_key', val.api_key);
+          if (val.extra_args) {
+            const args = val.extra_args.map((arg) => {
+              // Each entry is "key:value". Split on the FIRST ':' only, so
+              // a value that itself contains ':' (e.g. a URL) is not cut.
+              const separatorIndex = arg.indexOf(':');
+              const key =
+                separatorIndex === -1 ? arg : arg.slice(0, separatorIndex);
+              const value =
+                separatorIndex === -1 ? '' : arg.slice(separatorIndex + 1);
+
+              // The type is inferred from the text. Number('') is 0, so a
+              // blank value must be excluded explicitly or an empty string
+              // argument would be submitted as the number 0.
+              let type: 'string' | 'number' | 'boolean' = 'string';
+              if (value.trim() !== '' && !isNaN(Number(value))) {
+                type = 'number';
+              } else if (value === 'true' || value === 'false') {
+                type = 'boolean';
+              }
+              return {
+                key,
+                type,
+                value,
+              };
+            });
+            setExtraArgs(args);
+            form.setValue('extra_args', args);
+          }
+        });
+      } else {
+        form.reset();
+      }
+    });
+  }, []);
+
+  // Appends an empty row to the local list. The form value is not touched
+  // here; it is synced by updateExtraArg / removeExtraArg.
+  const addExtraArg = () => {
+    setExtraArgs([...extraArgs, { key: '', type: 'string', value: '' }]);
+  };
+
+  // Replaces one field of the row at `index` and syncs the form value.
+  const updateExtraArg = (
+    index: number,
+    field: 'key' | 'type' | 'value',
+    value: string,
+  ) => {
+    const newArgs = [...extraArgs];
+    newArgs[index] = {
+      ...newArgs[index],
+      [field]: value,
+    };
+    setExtraArgs(newArgs);
+    form.setValue('extra_args', newArgs);
+  };
+
+  // Drops the row at `index` and syncs the form value.
+  const removeExtraArg = (index: number) => {
+    const newArgs = extraArgs.filter((_, i) => i !== index);
+    setExtraArgs(newArgs);
+    form.setValue('extra_args', newArgs);
+  };
+
+  // Fetches the 'text-embedding' requesters and stores their display
+  // entries (label/value) and, in the same order, their default base URLs.
+  async function initEmbeddingModelFormComponent() {
+    const requesterNameList =
+      await httpClient.getProviderRequesters('text-embedding');
+    setRequesterNameList(
+      requesterNameList.requesters.map((item) => {
+        return {
+          label: i18nObj(item.label),
+          value:
item.name,
+        };
+      }),
+    );
+    // For every requester, pick the default of its 'base_url' config item
+    // (as a string), or '' when it has none.
+    setRequesterDefaultURLList(
+      requesterNameList.requesters.map((item) => {
+        const config = item.spec.config;
+        for (let i = 0; i < config.length; i++) {
+          if (config[i].name == 'base_url') {
+            return config[i].default?.toString() || '';
+          }
+        }
+        return '';
+      }),
+    );
+  }
+
+  // Loads one embedding model and maps it onto the form fields.
+  // extra_args is flattened into "key:value" strings; the original value
+  // type is not carried along, the consumer re-derives it from the text.
+  // Only the first entry of api_keys is used.
+  // NOTE(review): the Promise and Record type arguments below look like
+  // they were lost when this patch text was extracted - confirm against
+  // the real file.
+  async function getEmbeddingConfig(
+    id: string,
+  ): Promise {
+    const embeddingModel = await httpClient.getProviderEmbeddingModel(id);
+
+    const fakeExtraArgs = [];
+    const extraArgs = embeddingModel.model.extra_args as Record;
+    for (const key in extraArgs) {
+      fakeExtraArgs.push(`${key}:${extraArgs[key]}`);
+    }
+    return {
+      name: embeddingModel.model.name,
+      model_provider: embeddingModel.model.requester,
+      url: embeddingModel.model.requester_config?.base_url,
+      api_key: embeddingModel.model.api_keys[0],
+      extra_args: fakeExtraArgs,
+    };
+  }
+
+  // Converts the validated form value into an EmbeddingModel and sends it
+  // to the create or update endpoint depending on editMode.
+  // NOTE(review): the z.infer and Record type arguments look truncated in
+  // this copy of the patch - confirm against the real file.
+  function handleFormSubmit(value: z.infer) {
+    // Turn each row's string value back into the declared type.
+    const extraArgsObj: Record = {};
+    value.extra_args?.forEach(
+      (arg: { key: string; type: string; value: string }) => {
+        if (arg.type === 'number') {
+          extraArgsObj[arg.key] = Number(arg.value);
+        } else if (arg.type === 'boolean') {
+          extraArgsObj[arg.key] = arg.value === 'true';
+        } else {
+          extraArgsObj[arg.key] = arg.value;
+        }
+      },
+    );
+
+    // In edit mode the existing id is reused; otherwise a new UUID is
+    // generated on the client.
+    const embeddingModel: EmbeddingModel = {
+      uuid: editMode ?
initEmbeddingId || '' : UUID.generate(),
+      name: value.name,
+      description: '',
+      requester: value.model_provider,
+      requester_config: {
+        base_url: value.url,
+        timeout: 120,
+      },
+      extra_args: extraArgsObj,
+      api_keys: [value.api_key],
+    };
+
+    // Reset the form only after a successful request. Both helpers catch
+    // their own errors, so an unconditional reset would also run after a
+    // failure and wipe the user's input while the dialog is still open.
+    const request = editMode
+      ? onSaveEdit(embeddingModel)
+      : onCreateEmbedding(embeddingModel);
+    request.then((succeeded) => {
+      if (succeeded) {
+        form.reset();
+      }
+    });
+  }
+
+  // Creates the model. Notifies the parent and shows a success toast, or
+  // shows an error toast. Resolves to true on success, false on failure;
+  // it never rejects.
+  async function onCreateEmbedding(
+    embeddingModel: EmbeddingModel,
+  ): Promise<boolean> {
+    try {
+      await httpClient.createProviderEmbeddingModel(embeddingModel);
+      onFormSubmit();
+      toast.success(t('models.createSuccess'));
+      return true;
+    } catch (err) {
+      toast.error(t('models.createError') + (err as Error).message);
+      return false;
+    }
+  }
+
+  // Updates the model identified by initEmbeddingId. Same reporting and
+  // return contract as onCreateEmbedding.
+  async function onSaveEdit(embeddingModel: EmbeddingModel): Promise<boolean> {
+    try {
+      await httpClient.updateProviderEmbeddingModel(
+        initEmbeddingId || '',
+        embeddingModel,
+      );
+      onFormSubmit();
+      toast.success(t('models.saveSuccess'));
+      return true;
+    } catch (err) {
+      toast.error(t('models.saveError') + (err as Error).message);
+      return false;
+    }
+  }
+
+  // Deletes the model being edited; does nothing without an id.
+  function deleteModel() {
+    if (initEmbeddingId) {
+      httpClient
+        .deleteProviderEmbeddingModel(initEmbeddingId)
+        .then(() => {
+          onEmbeddingDeleted();
+          toast.success(t('models.deleteSuccess'));
+        })
+        .catch((err) => {
+          toast.error(t('models.deleteError') + err.message);
+        });
+    }
+  }
+
+  // Sends the current (unsaved) form values to the test endpoint and
+  // reports the outcome with a toast. modelTesting is true while the
+  // request is in flight.
+  // NOTE(review): extra_args are not part of the test payload, so the test
+  // may not reflect what would actually be saved - confirm this is intended.
+  function testEmbeddingModelInForm() {
+    setModelTesting(true);
+    httpClient
+      .testEmbeddingModel('_', {
+        uuid: '',
+        name: form.getValues('name'),
+        description: '',
+        requester: form.getValues('model_provider'),
+        requester_config: {
+          base_url: form.getValues('url'),
+          timeout: 120,
+        },
+        api_keys: [form.getValues('api_key')],
+      })
+      .then((res) => {
+        console.log(res);
+        toast.success(t('models.testSuccess'));
+      })
+      .catch(() => {
+        toast.error(t('models.testError'));
+      })
+      .finally(() => {
+        setModelTesting(false);
+      });
+  }
+
+  return (
+
+ + + + {t('common.confirmDelete')} + + + {t('models.deleteConfirmation')} + + + + + + + + +
+ +
+ ( + + + {t('models.modelName')} + * + + + + + + + {t('models.modelProviderDescription')} + + + )} + /> + + ( + + + {t('models.modelProvider')} + * + + + + + + + )} + /> + + ( + + + {t('models.requestURL')} + * + + + + + + + )} + /> + + ( + + + {t('models.apiKey')} + * + + + + + + + )} + /> + + + {t('models.extraParameters')} +
+ {extraArgs.map((arg, index) => ( +
+ + updateExtraArg(index, 'key', e.target.value) + } + /> + + + updateExtraArg(index, 'value', e.target.value) + } + /> + +
+ ))} + +
+ + {t('embedding.extraParametersDescription')} + + +
+
+ + {editMode && ( + + )} + + + + + + + +
+ +
+ ); +} diff --git a/web/src/app/home/models/component/llm-form/LLMForm.tsx b/web/src/app/home/models/component/llm-form/LLMForm.tsx index f483f183..73cc32fe 100644 --- a/web/src/app/home/models/component/llm-form/LLMForm.tsx +++ b/web/src/app/home/models/component/llm-form/LLMForm.tsx @@ -1,6 +1,6 @@ -import { ICreateLLMField } from '@/app/home/models/ICreateLLMField'; +import { ICreateLLMField } from '@/app/home/models/component/ICreateLLMField'; import { useEffect, useState } from 'react'; -import { IChooseRequesterEntity } from '@/app/home/models/component/llm-form/ChooseRequesterEntity'; +import { IChooseRequesterEntity } from '@/app/home/models/component/ChooseRequesterEntity'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; import { UUID } from 'uuidjs'; @@ -197,7 +197,7 @@ export default function LLMForm({ }; async function initLLMModelFormComponent() { - const requesterNameList = await httpClient.getProviderRequesters(); + const requesterNameList = await httpClient.getProviderRequesters('llm'); setRequesterNameList( requesterNameList.requesters.map((item) => { return { @@ -596,7 +596,7 @@ export default function LLMForm({
- {t('models.extraParametersDescription')} + {t('llm.extraParametersDescription')} diff --git a/web/src/app/home/models/page.tsx b/web/src/app/home/models/page.tsx index 3ccec486..2f936753 100644 --- a/web/src/app/home/models/page.tsx +++ b/web/src/app/home/models/page.tsx @@ -8,6 +8,7 @@ import LLMForm from '@/app/home/models/component/llm-form/LLMForm'; import CreateCardComponent from '@/app/infra/basic-component/create-card-component/CreateCardComponent'; import { httpClient } from '@/app/infra/http/HttpClient'; import { LLMModel } from '@/app/infra/entities/api'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; import { Dialog, DialogContent, @@ -17,6 +18,9 @@ import { import { toast } from 'sonner'; import { useTranslation } from 'react-i18next'; import { i18nObj } from '@/i18n/I18nProvider'; +import { EmbeddingCardVO } from '@/app/home/models/component/embedding-card/EmbeddingCardVO'; +import EmbeddingCard from '@/app/home/models/component/embedding-card/EmbeddingCard'; +import EmbeddingForm from '@/app/home/models/component/embedding-form/EmbeddingForm'; export default function LLMConfigPage() { const { t } = useTranslation(); @@ -24,13 +28,21 @@ export default function LLMConfigPage() { const [modalOpen, setModalOpen] = useState(false); const [isEditForm, setIsEditForm] = useState(false); const [nowSelectedLLM, setNowSelectedLLM] = useState(null); + const [embeddingCardList, setEmbeddingCardList] = useState( + [], + ); + const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false); + const [isEditEmbeddingForm, setIsEditEmbeddingForm] = useState(false); + const [nowSelectedEmbedding, setNowSelectedEmbedding] = + useState(null); useEffect(() => { getLLMModelList(); + getEmbeddingModelList(); }, []); async function getLLMModelList() { - const requesterNameListResp = await httpClient.getProviderRequesters(); + const requesterNameListResp = await httpClient.getProviderRequesters('llm'); const requesterNameList = 
requesterNameListResp.requesters.map((item) => { return { label: i18nObj(item.label), @@ -74,6 +86,55 @@ export default function LLMConfigPage() { setNowSelectedLLM(null); setModalOpen(true); } + function selectEmbedding(cardVO: EmbeddingCardVO) { + setIsEditEmbeddingForm(true); + setNowSelectedEmbedding(cardVO); + setEmbeddingModalOpen(true); + } + + function handleCreateEmbeddingModelClick() { + setIsEditEmbeddingForm(false); + setNowSelectedEmbedding(null); + setEmbeddingModalOpen(true); + } + async function getEmbeddingModelList() { + const requesterNameListResp = + await httpClient.getProviderRequesters('text-embedding'); + const requesterNameList = requesterNameListResp.requesters.map((item) => { + return { + label: i18nObj(item.label), + value: item.name, + }; + }); + + httpClient + .getProviderEmbeddingModels() + .then((resp) => { + const embeddingModelList: EmbeddingCardVO[] = resp.models.map( + (model: { + uuid: string; + requester: string; + name: string; + requester_config?: { base_url?: string }; + }) => { + return new EmbeddingCardVO({ + id: model.uuid, + iconURL: httpClient.getProviderRequesterIconURL(model.requester), + name: model.name, + providerLabel: + requesterNameList.find((item) => item.value === model.requester) + ?.label || model.requester.substring(0, 10), + baseURL: model.requester_config?.base_url || '', + }); + }, + ); + setEmbeddingCardList(embeddingModelList); + }) + .catch((err) => { + console.error('get Embedding model list error', err); + toast.error(t('embedding.getModelListError') + err.message); + }); + } return (
@@ -101,26 +162,108 @@ export default function LLMConfigPage() { /> -
- - {cardList.map((cardVO) => { - return ( -
{ - selectLLM(cardVO); - }} - > - + + + + + {isEditEmbeddingForm + ? t('embedding.editModel') + : t('embedding.createModel')} + + + { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + onFormCancel={() => { + setEmbeddingModalOpen(false); + }} + onEmbeddingDeleted={() => { + setEmbeddingModalOpen(false); + getEmbeddingModelList(); + }} + /> + + + + +
+
+ + + {t('llm.llmModels')} + + + {t('embedding.embeddingModels')} + + +
+ +
+

{t('llm.description')}

- ); - })} -
+ + +
+

+ {t('embedding.description')} +

+
+
+
+ + +
+ + {cardList.map((cardVO) => { + return ( +
{ + selectLLM(cardVO); + }} + > + +
+ ); + })} +
+
+ + +
+ + {embeddingCardList.map((cardVO) => { + return ( +
{ + selectEmbedding(cardVO); + }} + > + +
+ ); + })} +
+
+
); } diff --git a/web/src/app/infra/entities/api/index.ts b/web/src/app/infra/entities/api/index.ts index d86a8be0..f16333b1 100644 --- a/web/src/app/infra/entities/api/index.ts +++ b/web/src/app/infra/entities/api/index.ts @@ -55,6 +55,38 @@ export interface LLMModel { // updated_at: string; } +export interface KnowledgeBase { + uuid?: string; + name: string; + description: string; + embedding_model_uuid: string; + created_at?: string; + top_k?: number; +} + +export interface ApiRespProviderEmbeddingModels { + models: EmbeddingModel[]; +} + +export interface ApiRespProviderEmbeddingModel { + model: EmbeddingModel; +} + +export interface EmbeddingModel { + name: string; + description: string; + uuid: string; + requester: string; + requester_config: { + base_url: string; + timeout: number; + }; + extra_args?: object; + api_keys: string[]; + // created_at: string; + // updated_at: string; +} + export interface ApiRespPipelines { pipelines: Pipeline[]; } @@ -110,6 +142,33 @@ export interface Bot { updated_at?: string; } +export interface ApiRespKnowledgeBases { + bases: KnowledgeBase[]; +} + +export interface ApiRespKnowledgeBase { + base: KnowledgeBase; +} + +export interface KnowledgeBase { + uuid?: string; + name: string; + description: string; + embedding_model_uuid: string; + created_at?: string; + updated_at?: string; +} + +export interface ApiRespKnowledgeBaseFiles { + files: KnowledgeBaseFile[]; +} + +export interface KnowledgeBaseFile { + uuid: string; + file_name: string; + status: string; +} + // plugins export interface ApiRespPlugins { plugins: Plugin[]; diff --git a/web/src/app/infra/entities/form/dynamic.ts b/web/src/app/infra/entities/form/dynamic.ts index 6a185c8b..6d6de096 100644 --- a/web/src/app/infra/entities/form/dynamic.ts +++ b/web/src/app/infra/entities/form/dynamic.ts @@ -21,6 +21,7 @@ export enum DynamicFormItemType { LLM_MODEL_SELECTOR = 'llm-model-selector', PROMPT_EDITOR = 'prompt-editor', UNKNOWN = 'unknown', + KNOWLEDGE_BASE_SELECTOR = 
'knowledge-base-selector', } export interface IDynamicFormItemOption { diff --git a/web/src/app/infra/http/HttpClient.ts b/web/src/app/infra/http/HttpClient.ts index 5193703b..4967f66f 100644 --- a/web/src/app/infra/http/HttpClient.ts +++ b/web/src/app/infra/http/HttpClient.ts @@ -10,6 +10,9 @@ import { ApiRespProviderLLMModels, ApiRespProviderLLMModel, LLMModel, + ApiRespProviderEmbeddingModels, + ApiRespProviderEmbeddingModel, + EmbeddingModel, ApiRespPipelines, Pipeline, ApiRespPlatformAdapters, @@ -31,6 +34,10 @@ import { AsyncTask, ApiRespWebChatMessage, ApiRespWebChatMessages, + ApiRespKnowledgeBases, + ApiRespKnowledgeBase, + KnowledgeBase, + ApiRespKnowledgeBaseFiles, } from '@/app/infra/entities/api'; import { GetBotLogsRequest } from '@/app/infra/http/requestParam/bots/GetBotLogsRequest'; import { GetBotLogsResponse } from '@/app/infra/http/requestParam/bots/GetBotLogsResponse'; @@ -226,8 +233,10 @@ class HttpClient { // real api request implementation // ============ Provider API ============ - public getProviderRequesters(): Promise { - return this.get('/api/v1/provider/requesters'); + public getProviderRequesters( + model_type: string, + ): Promise { + return this.get('/api/v1/provider/requesters', { type: model_type }); } public getProviderRequester(name: string): Promise { @@ -275,6 +284,39 @@ class HttpClient { return this.post(`/api/v1/provider/models/llm/${uuid}/test`, model); } + // ============ Provider Model Embedding ============ + public getProviderEmbeddingModels(): Promise { + return this.get('/api/v1/provider/models/embedding'); + } + + public getProviderEmbeddingModel( + uuid: string, + ): Promise { + return this.get(`/api/v1/provider/models/embedding/${uuid}`); + } + + public createProviderEmbeddingModel(model: EmbeddingModel): Promise { + return this.post('/api/v1/provider/models/embedding', model); + } + + public deleteProviderEmbeddingModel(uuid: string): Promise { + return this.delete(`/api/v1/provider/models/embedding/${uuid}`); + } 
+ + public updateProviderEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.put(`/api/v1/provider/models/embedding/${uuid}`, model); + } + + public testEmbeddingModel( + uuid: string, + model: EmbeddingModel, + ): Promise { + return this.post(`/api/v1/provider/models/embedding/${uuid}/test`, model); + } + // ============ Pipeline API ============ public getGeneralPipelineMetadata(): Promise { // as designed, this method will be deprecated, and only for developer to check the prefered config schema @@ -389,6 +431,67 @@ class HttpClient { return this.post(`/api/v1/platform/bots/${botId}/logs`, request); } + // ============ File management API ============ + public uploadDocumentFile(file: File): Promise<{ file_id: string }> { + const formData = new FormData(); + formData.append('file', file); + + return this.request<{ file_id: string }>({ + method: 'post', + url: '/api/v1/files/documents', + data: formData, + headers: { + 'Content-Type': 'multipart/form-data', + }, + }); + } + + // ============ Knowledge Base API ============ + public getKnowledgeBases(): Promise { + return this.get('/api/v1/knowledge/bases'); + } + + public getKnowledgeBase(uuid: string): Promise { + return this.get(`/api/v1/knowledge/bases/${uuid}`); + } + + public createKnowledgeBase(base: KnowledgeBase): Promise<{ uuid: string }> { + return this.post('/api/v1/knowledge/bases', base); + } + + public updateKnowledgeBase( + uuid: string, + base: KnowledgeBase, + ): Promise<{ uuid: string }> { + return this.put(`/api/v1/knowledge/bases/${uuid}`, base); + } + + public uploadKnowledgeBaseFile( + uuid: string, + file_id: string, + ): Promise { + return this.post(`/api/v1/knowledge/bases/${uuid}/files`, { + file_id, + }); + } + + public getKnowledgeBaseFiles( + uuid: string, + ): Promise { + return this.get(`/api/v1/knowledge/bases/${uuid}/files`); + } + + public deleteKnowledgeBaseFile( + uuid: string, + file_id: string, + ): Promise { + return 
this.delete(`/api/v1/knowledge/bases/${uuid}/files/${file_id}`); + } + + public deleteKnowledgeBase(uuid: string): Promise { + return this.delete(`/api/v1/knowledge/bases/${uuid}`); + } + // ============ Plugins API ============ public getPlugins(): Promise { return this.get('/api/v1/plugins'); diff --git a/web/src/components/ui/dropdown-menu.tsx b/web/src/components/ui/dropdown-menu.tsx new file mode 100644 index 00000000..26027549 --- /dev/null +++ b/web/src/components/ui/dropdown-menu.tsx @@ -0,0 +1,257 @@ +'use client'; + +import * as React from 'react'; +import * as DropdownMenuPrimitive from '@radix-ui/react-dropdown-menu'; +import { CheckIcon, ChevronRightIcon, CircleIcon } from 'lucide-react'; + +import { cn } from '@/lib/utils'; + +function DropdownMenu({ + ...props +}: React.ComponentProps) { + return ; +} + +function DropdownMenuPortal({ + ...props +}: React.ComponentProps) { + return ( + + ); +} + +function DropdownMenuTrigger({ + ...props +}: React.ComponentProps) { + return ( + + ); +} + +function DropdownMenuContent({ + className, + sideOffset = 4, + ...props +}: React.ComponentProps) { + return ( + + + + ); +} + +function DropdownMenuGroup({ + ...props +}: React.ComponentProps) { + return ( + + ); +} + +function DropdownMenuItem({ + className, + inset, + variant = 'default', + ...props +}: React.ComponentProps & { + inset?: boolean; + variant?: 'default' | 'destructive'; +}) { + return ( + + ); +} + +function DropdownMenuCheckboxItem({ + className, + children, + checked, + ...props +}: React.ComponentProps) { + return ( + + + + + + + {children} + + ); +} + +function DropdownMenuRadioGroup({ + ...props +}: React.ComponentProps) { + return ( + + ); +} + +function DropdownMenuRadioItem({ + className, + children, + ...props +}: React.ComponentProps) { + return ( + + + + + + + {children} + + ); +} + +function DropdownMenuLabel({ + className, + inset, + ...props +}: React.ComponentProps & { + inset?: boolean; +}) { + return ( + + ); +} + +function 
DropdownMenuSeparator({ + className, + ...props +}: React.ComponentProps) { + return ( + + ); +} + +function DropdownMenuShortcut({ + className, + ...props +}: React.ComponentProps<'span'>) { + return ( + + ); +} + +function DropdownMenuSub({ + ...props +}: React.ComponentProps) { + return ; +} + +function DropdownMenuSubTrigger({ + className, + inset, + children, + ...props +}: React.ComponentProps & { + inset?: boolean; +}) { + return ( + + {children} + + + ); +} + +function DropdownMenuSubContent({ + className, + ...props +}: React.ComponentProps) { + return ( + + ); +} + +export { + DropdownMenu, + DropdownMenuPortal, + DropdownMenuTrigger, + DropdownMenuContent, + DropdownMenuGroup, + DropdownMenuLabel, + DropdownMenuItem, + DropdownMenuCheckboxItem, + DropdownMenuRadioGroup, + DropdownMenuRadioItem, + DropdownMenuSeparator, + DropdownMenuShortcut, + DropdownMenuSub, + DropdownMenuSubTrigger, + DropdownMenuSubContent, +}; diff --git a/web/src/components/ui/table.tsx b/web/src/components/ui/table.tsx new file mode 100644 index 00000000..ebded8ed --- /dev/null +++ b/web/src/components/ui/table.tsx @@ -0,0 +1,116 @@ +'use client'; + +import * as React from 'react'; + +import { cn } from '@/lib/utils'; + +function Table({ className, ...props }: React.ComponentProps<'table'>) { + return ( +
+ + + ); +} + +function TableHeader({ className, ...props }: React.ComponentProps<'thead'>) { + return ( + + ); +} + +function TableBody({ className, ...props }: React.ComponentProps<'tbody'>) { + return ( + + ); +} + +function TableFooter({ className, ...props }: React.ComponentProps<'tfoot'>) { + return ( + tr]:last:border-b-0', + className, + )} + {...props} + /> + ); +} + +function TableRow({ className, ...props }: React.ComponentProps<'tr'>) { + return ( + + ); +} + +function TableHead({ className, ...props }: React.ComponentProps<'th'>) { + return ( +
[role=checkbox]]:translate-y-[2px]', + className, + )} + {...props} + /> + ); +} + +function TableCell({ className, ...props }: React.ComponentProps<'td'>) { + return ( + [role=checkbox]]:translate-y-[2px]', + className, + )} + {...props} + /> + ); +} + +function TableCaption({ + className, + ...props +}: React.ComponentProps<'caption'>) { + return ( +
+ ); +} + +export { + Table, + TableHeader, + TableBody, + TableFooter, + TableHead, + TableRow, + TableCell, + TableCaption, +}; diff --git a/web/src/i18n/locales/en-US.ts b/web/src/i18n/locales/en-US.ts index 1975a521..e194c58b 100644 --- a/web/src/i18n/locales/en-US.ts +++ b/web/src/i18n/locales/en-US.ts @@ -86,14 +86,13 @@ const enUS = { string: 'String', number: 'Number', boolean: 'Boolean', - extraParametersDescription: - 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', selectModelProvider: 'Select Model Provider', modelProviderDescription: 'Please fill in the model name provided by the supplier', selectModel: 'Select Model', testSuccess: 'Test successful', testError: 'Test failed, please check your model configuration', + llmModels: 'LLM Models', }, bots: { title: 'Bots', @@ -231,6 +230,55 @@ const enUS = { atTips: 'Mention the bot', }, }, + knowledge: { + title: 'Knowledge', + createKnowledgeBase: 'Create Knowledge Base', + editKnowledgeBase: 'Edit Knowledge Base', + selectKnowledgeBase: 'Select Knowledge Base', + empty: 'Empty', + editDocument: 'Documents', + description: 'Configuring knowledge bases for improved LLM responses', + metadata: 'Metadata', + documents: 'Documents', + kbNameRequired: 'Knowledge base name cannot be empty', + kbDescriptionRequired: 'Knowledge base description cannot be empty', + embeddingModelUUIDRequired: 'Embedding model cannot be empty', + daysAgo: 'days ago', + today: 'Today', + kbName: 'Knowledge Base Name', + kbDescription: 'Knowledge Base Description', + defaultDescription: 'A knowledge base', + embeddingModelUUID: 'Embedding Model', + selectEmbeddingModel: 'Select Embedding Model', + embeddingModelDescription: + 'Used to vectorize the text, you can configure it in the Models page', + updateTime: 'Updated ', + cannotChangeEmbeddingModel: + 'Knowledge base created cannot be modified embedding model', + updateKnowledgeBaseSuccess: 'Knowledge base updated successfully', + 
updateKnowledgeBaseFailed: 'Knowledge base update failed', + documentsTab: { + name: 'Name', + status: 'Status', + noResults: 'No documents', + dragAndDrop: 'Drag and drop files here or click to upload', + uploading: 'Uploading...', + supportedFormats: + 'Supports PDF, Word, TXT, Markdown and other document formats', + uploadSuccess: 'File uploaded successfully!', + uploadError: 'File upload failed, please try again', + uploadingFile: 'Uploading file...', + actions: 'Actions', + delete: 'Delete File', + fileDeleteSuccess: 'File deleted successfully', + fileDeleteFailed: 'File deletion failed', + processing: 'Processing', + completed: 'Completed', + failed: 'Failed', + }, + deleteKnowledgeBaseConfirmation: + 'Are you sure you want to delete this knowledge base? All documents in this knowledge base will be deleted.', + }, register: { title: 'Initialize LangBot 👋', description: 'This is your first time starting LangBot', @@ -259,6 +307,21 @@ const enUS = { 'Password reset failed, please check your email and recovery key', backToLogin: 'Back to Login', }, + embedding: { + description: 'Manage Embedding models for text vectorization', + createModel: 'Create Embedding Model', + editModel: 'Edit Embedding Model', + getModelListError: 'Failed to get Embedding model list: ', + embeddingModels: 'Embedding', + extraParametersDescription: + 'Will be attached to the request body, such as encoding_format, dimensions, etc.', + }, + llm: { + description: 'Manage LLM models for conversation generation', + llmModels: 'LLM', + extraParametersDescription: + 'Will be attached to the request body, such as max_tokens, temperature, top_p, etc.', + }, }; export default enUS; diff --git a/web/src/i18n/locales/ja-JP.ts b/web/src/i18n/locales/ja-JP.ts index bac6f805..a5ea9c04 100644 --- a/web/src/i18n/locales/ja-JP.ts +++ b/web/src/i18n/locales/ja-JP.ts @@ -232,6 +232,56 @@ const jaJP = { atTips: 'ボットをメンション', }, }, + knowledge: { + title: '知識ベース', + createKnowledgeBase: '知識ベースを作成', + 
editKnowledgeBase: '知識ベースを編集',
+    selectKnowledgeBase: '知識ベースを選択',
+    empty: 'なし',
+    editDocument: 'ドキュメント',
+    description: 'LLMの回答品質向上のための知識ベースを設定します',
+    metadata: 'メタデータ',
+    documents: 'ドキュメント',
+    kbNameRequired: '知識ベース名は必須です',
+    kbDescriptionRequired: '知識ベースの説明は必須です',
+    embeddingModelUUIDRequired: '埋め込みモデルは必須です',
+    daysAgo: '日前',
+    today: '今日',
+    kbName: '知識ベース名',
+    kbDescription: '知識ベースの説明',
+    defaultDescription: '知識ベース',
+    embeddingModelUUID: '埋め込みモデル',
+    selectEmbeddingModel: '埋め込みモデルを選択',
+    // Hint for the embedding-model field. It mirrors the en-US / zh-Hans
+    // text (what the model is used for and where to configure it) rather
+    // than repeating the Models-page description.
+    embeddingModelDescription:
+      'テキストのベクトル化に使用します。モデル設定ページで設定できます',
+    updateTime: '更新日時',
+    cannotChangeEmbeddingModel:
+      '知識ベース作成後は埋め込みモデルを変更できません',
+    updateKnowledgeBaseSuccess: '知識ベースの更新に成功しました',
+    updateKnowledgeBaseFailed: '知識ベースの更新に失敗しました',
+    documentsTab: {
+      name: '名前',
+      status: 'ステータス',
+      noResults: 'ドキュメントがありません',
+      dragAndDrop:
+        'ファイルをここにドラッグ&ドロップするか、クリックしてアップロードしてください',
+      uploading: 'アップロード中...',
+      supportedFormats:
+        'PDF、Word、TXT、Markdownなどのドキュメントファイルをサポートしています',
+      uploadSuccess: 'ファイルのアップロードに成功しました!',
+      uploadError: 'ファイルのアップロードに失敗しました。再度お試しください',
+      uploadingFile: 'ファイルをアップロード中...',
+      actions: 'アクション',
+      delete: 'ドキュメントを削除',
+      fileDeleteSuccess: 'ドキュメントの削除に成功しました',
+      fileDeleteFailed: 'ドキュメントの削除に失敗しました',
+      processing: '処理中',
+      completed: '完了',
+      failed: '失敗',
+    },
+    deleteKnowledgeBaseConfirmation:
+      '本当にこの知識ベースを削除しますか?この知識ベースに紐付けられたドキュメントは削除されます。',
+  },
   register: {
     title: 'LangBot を初期化 👋',
     description: 'これはLangBotの初回起動です',
@@ -260,6 +310,21 @@ const jaJP = {
       'パスワードのリセットに失敗しました。メールアドレスと復旧キーを確認してください',
     backToLogin: 'ログインに戻る',
   },
+  embedding: {
+    description: 'テキストのベクトル化に使用する埋め込みモデルを管理します',
+    createModel: '埋め込みモデルを作成',
+    editModel: '埋め込みモデルを編集',
+    getModelListError: '埋め込みモデルリストの取得に失敗しました:',
+    embeddingModels: '埋め込みモデル',
+    extraParametersDescription:
+      'リクエストボディに追加されるパラメータ(encoding_format、dimensions など)',
+  },
+  llm: {
+    description: 'チャットメッセージの生成に使用するLLMモデルを管理します',
+    llmModels: 'LLMモデル',
+    extraParametersDescription:
+
'リクエストボディに追加されるパラメータ(max_tokens、temperature、top_p など)', + }, }; export default jaJP; diff --git a/web/src/i18n/locales/zh-Hans.ts b/web/src/i18n/locales/zh-Hans.ts index 2ded8236..621bb16c 100644 --- a/web/src/i18n/locales/zh-Hans.ts +++ b/web/src/i18n/locales/zh-Hans.ts @@ -87,13 +87,12 @@ const zhHans = { string: '字符串', number: '数字', boolean: '布尔值', - extraParametersDescription: - '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', selectModelProvider: '选择模型供应商', modelProviderDescription: '请填写供应商向您提供的模型名称', selectModel: '请选择模型', testSuccess: '测试成功', testError: '测试失败,请检查模型配置', + llmModels: '对话模型', }, bots: { title: '机器人', @@ -226,6 +225,52 @@ const zhHans = { atTips: '提及机器人', }, }, + knowledge: { + title: '知识库', + createKnowledgeBase: '创建知识库', + editKnowledgeBase: '编辑知识库', + selectKnowledgeBase: '选择知识库', + empty: '无', + editDocument: '文档', + description: '配置可用于提升模型回复质量的知识库', + metadata: '元数据', + documents: '文档', + kbNameRequired: '知识库名称不能为空', + kbDescriptionRequired: '知识库描述不能为空', + embeddingModelUUIDRequired: '嵌入模型不能为空', + daysAgo: '天前', + today: '今天', + kbName: '知识库名称', + kbDescription: '知识库描述', + defaultDescription: '一个知识库', + embeddingModelUUID: '嵌入模型', + selectEmbeddingModel: '选择嵌入模型', + embeddingModelDescription: '用于向量化文本,可在模型配置页面配置', + updateTime: '更新于', + cannotChangeEmbeddingModel: '知识库创建后不可修改嵌入模型', + updateKnowledgeBaseSuccess: '知识库更新成功', + updateKnowledgeBaseFailed: '知识库更新失败', + documentsTab: { + name: '名称', + status: '状态', + noResults: '暂无文档', + dragAndDrop: '拖拽文件到此处或点击上传', + uploading: '上传中...', + supportedFormats: '支持 PDF、Word、TXT、Markdown 等文档格式', + uploadSuccess: '文件上传成功!', + uploadError: '文件上传失败,请重试', + uploadingFile: '上传文件中...', + actions: '操作', + delete: '删除文件', + fileDeleteSuccess: '文件删除成功', + fileDeleteFailed: '文件删除失败', + processing: '处理中', + completed: '完成', + failed: '失败', + }, + deleteKnowledgeBaseConfirmation: + '你确定要删除这个知识库吗?此知识库下的所有文档将被删除。', + }, register: { title: '初始化 LangBot 👋', description: '这是您首次启动 LangBot', @@ -251,6 +296,21 @@ const 
zhHans = { resetFailed: '密码重置失败,请检查邮箱和恢复密钥是否正确', backToLogin: '返回登录', }, + embedding: { + description: '管理嵌入模型,用于向量化文本', + createModel: '创建嵌入模型', + editModel: '编辑嵌入模型', + getModelListError: '获取嵌入模型列表失败:', + embeddingModels: '嵌入模型', + extraParametersDescription: + '将在请求时附加到请求体中,如 encoding_format, dimensions 等', + }, + llm: { + llmModels: '对话模型', + description: '管理 LLM 模型,用于对话消息生成', + extraParametersDescription: + '将在请求时附加到请求体中,如 max_tokens, temperature, top_p 等', + }, }; export default zhHans;