Files
LangBot/src/langbot/libs/weknora_api/client.py
Typer_Body 0c6f71738c deerflow
2026-06-07 02:17:40 +08:00

181 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import httpx
import typing
import json
from .errors import WeKnoraAPIError
class AsyncWeKnoraClient:
"""WeKnora API 客户端"""
api_key: str
base_url: str
def __init__(
self,
api_key: str,
base_url: str = 'http://localhost:80/api/v1',
) -> None:
self.api_key = api_key
self.base_url = base_url
async def create_session(
self,
title: str = '',
description: str = '',
timeout: float = 30.0,
) -> str:
"""创建会话,返回 session_id"""
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
payload: dict[str, typing.Any] = {}
if title:
payload['title'] = title
if description:
payload['description'] = description
response = await client.post(
'/sessions',
headers={
'X-API-Key': self.api_key,
'Content-Type': 'application/json',
},
json=payload,
)
if response.status_code not in (200, 201):
raise WeKnoraAPIError(f'{response.status_code} {response.text}')
data = response.json()
return data['data']['id']
async def agent_chat(
self,
session_id: str,
query: str,
user: str,
agent_id: str = '',
knowledge_base_ids: list[str] | None = None,
web_search_enabled: bool = False,
timeout: float = 120.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""
Agent 智能对话SSE 流式)
响应事件类型:
- agent_query: Agent 开始处理
- thinking: 思考过程
- tool_call: 工具调用
- tool_result: 工具结果
- references: 知识库引用
- answer: 回答内容
- reflection: 反思
- session_title: 会话标题
- error: 错误
"""
if knowledge_base_ids is None:
knowledge_base_ids = []
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
payload: dict[str, typing.Any] = {
'query': query,
'agent_enabled': True,
'channel': 'im',
}
if agent_id:
payload['agent_id'] = agent_id
if knowledge_base_ids:
payload['knowledge_base_ids'] = knowledge_base_ids
if web_search_enabled:
payload['web_search_enabled'] = True
async with client.stream(
'POST',
f'/agent-chat/{session_id}',
headers={
'X-API-Key': self.api_key,
'Content-Type': 'application/json',
},
json=payload,
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise WeKnoraAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == '':
continue
if chunk.startswith('data:'):
try:
data = json.loads(chunk[5:].strip())
except json.JSONDecodeError:
continue
yield data
# 收到 error 事件后主动结束流,避免上层未 raise 时持续等待
if data.get('response_type') == 'error':
return
async def knowledge_chat(
self,
session_id: str,
query: str,
user: str,
agent_id: str = 'builtin-quick-answer',
knowledge_base_ids: list[str] | None = None,
timeout: float = 120.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""
知识库 RAG 问答SSE 流式)
响应事件类型:
- references: 知识库引用
- answer: 回答内容
"""
if knowledge_base_ids is None:
knowledge_base_ids = []
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
payload: dict[str, typing.Any] = {
'query': query,
'channel': 'im',
}
if agent_id:
payload['agent_id'] = agent_id
if knowledge_base_ids:
payload['knowledge_base_ids'] = knowledge_base_ids
async with client.stream(
'POST',
f'/knowledge-chat/{session_id}',
headers={
'X-API-Key': self.api_key,
'Content-Type': 'application/json',
},
json=payload,
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise WeKnoraAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == '':
continue
if chunk.startswith('data:'):
try:
data = json.loads(chunk[5:].strip())
except json.JSONDecodeError:
continue
yield data
# 收到 error 事件后主动结束流,避免上层未 raise 时持续等待
if data.get('response_type') == 'error':
return