mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-12 16:56:02 +00:00
181 lines
5.7 KiB
Python
181 lines
5.7 KiB
Python
from __future__ import annotations
|
||
|
||
import httpx
|
||
import typing
|
||
import json
|
||
|
||
from .errors import WeKnoraAPIError
|
||
|
||
|
||
class AsyncWeKnoraClient:
    """Async client for the WeKnora HTTP API.

    Each public method opens a short-lived ``httpx.AsyncClient`` rooted at
    ``base_url`` (with ``trust_env=True`` so proxy-related environment
    variables are honoured) and authenticates via the ``X-API-Key`` header.
    """

    api_key: str
    base_url: str

    def __init__(
        self,
        api_key: str,
        base_url: str = 'http://localhost:80/api/v1',
    ) -> None:
        """Store the credentials and API root used by every request.

        Args:
            api_key: Value sent in the ``X-API-Key`` header.
            base_url: Root URL of the WeKnora API; request paths such as
                ``/sessions`` are resolved relative to it.
        """
        self.api_key = api_key
        self.base_url = base_url

    def _headers(self) -> dict[str, str]:
        """Return the headers shared by every WeKnora request."""
        return {
            'X-API-Key': self.api_key,
            'Content-Type': 'application/json',
        }

    async def create_session(
        self,
        title: str = '',
        description: str = '',
        timeout: float = 30.0,
    ) -> str:
        """Create a conversation session and return its ``session_id``.

        Args:
            title: Optional session title; omitted from the request when empty.
            description: Optional description; omitted from the request when empty.
            timeout: Request timeout in seconds.

        Returns:
            The value of ``data.id`` from the JSON response.

        Raises:
            WeKnoraAPIError: on a status other than 200/201, or when the
                response body does not contain ``data.id``.
        """
        async with httpx.AsyncClient(
            base_url=self.base_url,
            trust_env=True,
            timeout=timeout,
        ) as client:
            payload: dict[str, typing.Any] = {}
            if title:
                payload['title'] = title
            if description:
                payload['description'] = description

            response = await client.post(
                '/sessions',
                headers=self._headers(),
                json=payload,
            )

            if response.status_code not in (200, 201):
                raise WeKnoraAPIError(f'{response.status_code} {response.text}')

            # A malformed success body (non-JSON, missing keys, wrong types)
            # is reported as an API error rather than a bare KeyError/TypeError.
            try:
                return response.json()['data']['id']
            except (ValueError, KeyError, TypeError) as exc:
                raise WeKnoraAPIError(
                    f'unexpected create_session response: {response.text}'
                ) from exc

    @staticmethod
    def _parse_sse_line(line: str) -> dict[str, typing.Any] | None:
        """Decode one SSE line into a JSON object, or return None to skip it.

        Only ``data:`` lines are considered; blank lines, other SSE fields,
        payloads that are not valid JSON and JSON values that are not objects
        are all skipped.
        """
        if not line.startswith('data:'):
            return None
        try:
            event = json.loads(line[5:].strip())
        except json.JSONDecodeError:
            return None
        # Callers (and the error check in _stream_sse) rely on dict.get().
        return event if isinstance(event, dict) else None

    async def _stream_sse(
        self,
        path: str,
        payload: dict[str, typing.Any],
        timeout: float,
    ) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
        """POST ``payload`` to ``path`` and yield each JSON event of the SSE reply.

        The stream ends after an event whose ``response_type`` is ``'error'``.

        Raises:
            WeKnoraAPIError: if the server responds with a non-200 status.
        """
        async with httpx.AsyncClient(
            base_url=self.base_url,
            trust_env=True,
            timeout=timeout,
        ) as client:
            async with client.stream(
                'POST',
                path,
                headers=self._headers(),
                json=payload,
            ) as r:
                # Check the status before iterating: an error response with an
                # empty body would otherwise produce no lines and end silently.
                if r.status_code != 200:
                    await r.aread()
                    raise WeKnoraAPIError(f'{r.status_code} {r.text}')

                async for line in r.aiter_lines():
                    event = self._parse_sse_line(line)
                    if event is None:
                        continue
                    yield event
                    # Stop after an error event so callers that do not raise
                    # on it are not left waiting for more data.
                    if event.get('response_type') == 'error':
                        return

    async def agent_chat(
        self,
        session_id: str,
        query: str,
        user: str,
        agent_id: str = '',
        knowledge_base_ids: list[str] | None = None,
        web_search_enabled: bool = False,
        timeout: float = 120.0,
    ) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
        """Run an agent conversation turn, streamed over SSE.

        Event types (``response_type``) emitted by the server:
            - agent_query: the agent started processing
            - thinking: reasoning progress
            - tool_call: a tool invocation
            - tool_result: a tool's result
            - references: knowledge-base references
            - answer: answer content
            - reflection: reflection step
            - session_title: session title
            - error: an error; the stream ends after this event

        Args:
            session_id: Target session, placed in the URL path.
            query: User message.
            user: Accepted for interface compatibility; not sent to the server.
            agent_id: Optional agent id; omitted when empty.
            knowledge_base_ids: Optional knowledge bases; omitted when empty.
            web_search_enabled: Sends ``web_search_enabled: true`` when set.
            timeout: Request timeout in seconds.

        Yields:
            Each decoded JSON event object.

        Raises:
            WeKnoraAPIError: if the server responds with a non-200 status.
        """
        payload: dict[str, typing.Any] = {
            'query': query,
            'agent_enabled': True,
            'channel': 'im',
        }
        if agent_id:
            payload['agent_id'] = agent_id
        if knowledge_base_ids:
            payload['knowledge_base_ids'] = knowledge_base_ids
        if web_search_enabled:
            payload['web_search_enabled'] = True

        events = self._stream_sse(f'/agent-chat/{session_id}', payload, timeout)
        try:
            async for event in events:
                yield event
        finally:
            # Close the inner generator promptly if the caller stops early,
            # so the HTTP stream and client are released immediately.
            await events.aclose()

    async def knowledge_chat(
        self,
        session_id: str,
        query: str,
        user: str,
        agent_id: str = 'builtin-quick-answer',
        knowledge_base_ids: list[str] | None = None,
        timeout: float = 120.0,
    ) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
        """Run a knowledge-base RAG question, streamed over SSE.

        Event types (``response_type``) emitted by the server:
            - references: knowledge-base references
            - answer: answer content

        Args:
            session_id: Target session, placed in the URL path.
            query: User question.
            user: Accepted for interface compatibility; not sent to the server.
            agent_id: Agent id; omitted when empty.
            knowledge_base_ids: Optional knowledge bases; omitted when empty.
            timeout: Request timeout in seconds.

        Yields:
            Each decoded JSON event object. The stream ends after an event
            whose ``response_type`` is ``'error'``.

        Raises:
            WeKnoraAPIError: if the server responds with a non-200 status.
        """
        payload: dict[str, typing.Any] = {
            'query': query,
            'channel': 'im',
        }
        if agent_id:
            payload['agent_id'] = agent_id
        if knowledge_base_ids:
            payload['knowledge_base_ids'] = knowledge_base_ids

        events = self._stream_sse(f'/knowledge-chat/{session_id}', payload, timeout)
        try:
            async for event in events:
                yield event
        finally:
            # Close the inner generator promptly if the caller stops early,
            # so the HTTP stream and client are released immediately.
            await events.aclose()