From 6642498f00dafff038a2c3268e7d61717df26adb Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Tue, 17 Dec 2024 00:41:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=AF=B9=20agent=20?= =?UTF-8?q?=E5=BA=94=E7=94=A8=E7=9A=84=E6=94=AF=E6=8C=81=20(#951)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libs/dify_service_api/test.py | 6 +- libs/dify_service_api/v1/client.py | 27 +++--- .../m017_dify_api_timeout_params.py | 7 +- pkg/provider/runners/difysvapi.py | 84 ++++++++++++++++--- templates/provider.json | 4 + 5 files changed, 103 insertions(+), 25 deletions(-) diff --git a/libs/dify_service_api/test.py b/libs/dify_service_api/test.py index 4c2662fa..faf7571a 100644 --- a/libs/dify_service_api/test.py +++ b/libs/dify_service_api/test.py @@ -10,8 +10,8 @@ class TestDifyClient: async def test_chat_messages(self): cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL")) - resp = await cln.chat_messages(inputs={}, query="Who are you?", user="test") - print(json.dumps(resp, ensure_ascii=False, indent=4)) + async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"): + print(json.dumps(chunk, ensure_ascii=False, indent=4)) async def test_upload_file(self): cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL")) @@ -41,4 +41,4 @@ class TestDifyClient: print(json.dumps(chunks, ensure_ascii=False, indent=4)) if __name__ == "__main__": - asyncio.run(TestDifyClient().test_workflow_run()) + asyncio.run(TestDifyClient().test_chat_messages()) diff --git a/libs/dify_service_api/v1/client.py b/libs/dify_service_api/v1/client.py index 91b60052..efa70ea5 100644 --- a/libs/dify_service_api/v1/client.py +++ b/libs/dify_service_api/v1/client.py @@ -26,21 +26,22 @@ class AsyncDifyServiceClient: inputs: dict[str, typing.Any], query: str, user: str, - response_mode: str = "blocking", # 当前不支持 streaming + response_mode: str = "streaming", # 当前不支持 blocking conversation_id: str = "", files: list[dict[str, typing.Any]] = [], timeout: float = 30.0, - ) -> dict[str, typing.Any]: + ) -> typing.AsyncGenerator[dict[str, typing.Any], None]: """发送消息""" - if response_mode != "blocking": - raise DifyAPIError("当前仅支持 blocking 模式") + if response_mode != "streaming": + raise DifyAPIError("当前仅支持 streaming 模式") async with httpx.AsyncClient( base_url=self.base_url, trust_env=True, timeout=timeout, ) as client: - response = await client.post( + async with client.stream( + "POST", "/chat-messages", headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}, json={ @@ -51,12 +52,14 @@ class AsyncDifyServiceClient: "conversation_id": conversation_id, "files": files, }, - ) - - if response.status_code != 200: - raise DifyAPIError(f"{response.status_code} {response.text}") - - return response.json() + ) as r: + async for chunk in r.aiter_lines(): + if r.status_code != 200: + raise DifyAPIError(f"{r.status_code} {chunk}") + if chunk.strip() == "": + continue + if chunk.startswith("data:"): + yield json.loads(chunk[5:]) async def workflow_run( self, @@ -88,6 +91,8 @@ class AsyncDifyServiceClient: }, ) as r: async for chunk in r.aiter_lines(): + if r.status_code != 200: + raise DifyAPIError(f"{r.status_code} {chunk}") if chunk.strip() == "": continue if chunk.startswith("data:"): diff --git a/pkg/core/migrations/m017_dify_api_timeout_params.py b/pkg/core/migrations/m017_dify_api_timeout_params.py index e0837732..a0e502a4 100644 --- a/pkg/core/migrations/m017_dify_api_timeout_params.py +++ b/pkg/core/migrations/m017_dify_api_timeout_params.py @@ -9,11 +9,16 @@ class DifyAPITimeoutParamsMigration(migration.Migration): async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] + return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] \ + or 'agent' not in self.ap.provider_cfg.data['dify-service-api'] async def run(self): """执行迁移""" self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120 self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120 + self.ap.provider_cfg.data['dify-service-api']['agent'] = { + "api-key": "app-1234567890", + "timeout": 120 + } await self.ap.provider_cfg.dump_config() diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index 87c9761c..4fed4277 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -20,7 +20,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): async def initialize(self): """初始化""" - valid_app_types = ["chat", "workflow"] + valid_app_types = ["chat", "agent", "workflow"] if ( self.ap.provider_cfg.data["dify-service-api"]["app-type"] not in valid_app_types @@ -85,23 +85,84 @@ class DifyServiceAPIRunner(runner.RequestRunner): for image_id in image_ids ] - resp = await self.dify_client.chat_messages( + async for chunk in self.dify_client.chat_messages( inputs={}, query=plain_text, user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", conversation_id=cov_id, files=files, timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"], - ) + ): + self.ap.logger.debug("dify-chat-chunk: "+chunk) + if chunk['event'] == 'node_finished': + if chunk['data']['node_type'] == 'answer': + yield llm_entities.Message( + role="assistant", + content=chunk['data']['outputs']['answer'], + ) - msg = llm_entities.Message( - role="assistant", - content=resp["answer"], - ) + query.session.using_conversation.uuid = chunk["conversation_id"] - yield msg + async def _agent_chat_messages( + self, query: core_entities.Query + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """调用聊天助手""" + cov_id = query.session.using_conversation.uuid or "" - query.session.using_conversation.uuid = resp["conversation_id"] + plain_text, image_ids = await self._preprocess_user_message(query) + + files = [ + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": image_id, + } + for image_id in image_ids + ] + + ignored_events = ["agent_message"] + + async for chunk in self.dify_client.chat_messages( + inputs={}, + query=plain_text, + user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", + response_mode="streaming", + conversation_id=cov_id, + files=files, + timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"], + ): + self.ap.logger.debug("dify-agent-chunk: "+chunk) + if chunk["event"] in ignored_events: + continue + if chunk["event"] == "agent_thought": + + if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过 + continue + + if chunk['thought'].strip() != '': # 文字回复内容 + msg = llm_entities.Message( + role="assistant", + content=chunk["thought"], + ) + yield msg + + if chunk['tool']: + msg = llm_entities.Message( + role="assistant", + tool_calls=[ + llm_entities.ToolCall( + id=chunk['id'], + type="function", + function=llm_entities.FunctionCall( + name=chunk["tool"], + arguments=json.dumps({}), + ), + ) + ], + ) + yield msg + + query.session.using_conversation.uuid = chunk["conversation_id"] async def _workflow_messages( self, query: core_entities.Query @@ -136,7 +197,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): files=files, timeout=self.ap.provider_cfg.data["dify-service-api"]["workflow"]["timeout"], ): - + self.ap.logger.debug("dify-workflow-chunk: "+chunk) if chunk["event"] in ignored_events: continue @@ -185,6 +246,9 @@ class DifyServiceAPIRunner(runner.RequestRunner): if self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "chat": async for msg in self._chat_messages(query): yield msg + elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "agent": + async for msg in self._agent_chat_messages(query): + yield msg elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "workflow": async for msg in self._workflow_messages(query): yield msg diff --git a/templates/provider.json b/templates/provider.json index c46ae793..30656f8c 100644 --- a/templates/provider.json +++ b/templates/provider.json @@ -65,6 +65,10 @@ "api-key": "app-1234567890", "timeout": 120 }, + "agent": { + "api-key": "app-1234567890", + "timeout": 120 + }, "workflow": { "api-key": "app-1234567890", "output-key": "summary",