feat: 添加对 agent 应用的支持 (#951)

This commit is contained in:
Junyan Qin
2024-12-17 00:41:28 +08:00
parent 32b400dcb1
commit 6642498f00
5 changed files with 103 additions and 25 deletions
+3 -3
View File
@@ -10,8 +10,8 @@ class TestDifyClient:
async def test_chat_messages(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
resp = await cln.chat_messages(inputs={}, query="Who are you?", user="test")
print(json.dumps(resp, ensure_ascii=False, indent=4))
async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"):
print(json.dumps(chunk, ensure_ascii=False, indent=4))
async def test_upload_file(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
@@ -41,4 +41,4 @@ class TestDifyClient:
print(json.dumps(chunks, ensure_ascii=False, indent=4))
if __name__ == "__main__":
asyncio.run(TestDifyClient().test_workflow_run())
asyncio.run(TestDifyClient().test_chat_messages())
+16 -11
View File
@@ -26,21 +26,22 @@ class AsyncDifyServiceClient:
inputs: dict[str, typing.Any],
query: str,
user: str,
response_mode: str = "blocking", # 当前不支持 streaming
response_mode: str = "streaming", # 当前不支持 blocking
conversation_id: str = "",
files: list[dict[str, typing.Any]] = [],
timeout: float = 30.0,
) -> dict[str, typing.Any]:
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""发送消息"""
if response_mode != "blocking":
raise DifyAPIError("当前仅支持 blocking 模式")
if response_mode != "streaming":
raise DifyAPIError("当前仅支持 streaming 模式")
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
response = await client.post(
async with client.stream(
"POST",
"/chat-messages",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
json={
@@ -51,12 +52,14 @@ class AsyncDifyServiceClient:
"conversation_id": conversation_id,
"files": files,
},
)
if response.status_code != 200:
raise DifyAPIError(f"{response.status_code} {response.text}")
return response.json()
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}")
if chunk.strip() == "":
continue
if chunk.startswith("data:"):
yield json.loads(chunk[5:])
async def workflow_run(
self,
@@ -88,6 +91,8 @@ class AsyncDifyServiceClient:
},
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}")
if chunk.strip() == "":
continue
if chunk.startswith("data:"):