import json
import asyncio
import aiohttp
import io
from typing import Dict, List, Any, AsyncGenerator
import os
from pathlib import Path


class CozeAPIError(Exception):
    """Raised by AsyncCozeAPIClient when a Coze API call fails.

    Subclasses ``Exception`` so existing ``except Exception`` handlers keep
    catching it; the dedicated type lets the client re-raise its own errors
    without wrapping them a second time.
    """


class _SSEParser:
    """Incrementally decode a ``text/event-stream`` body into event dicts.

    Raw bytes may be fed in arbitrary chunks. Complete lines are split on
    ``\\n`` (a byte that never occurs inside a multi-byte UTF-8 sequence, so
    decoding whole lines is safe). Each event is emitted as
    ``{'event': <event field or None>, 'data': <JSON-decoded data or {}>}``
    when a blank line is seen, or when the stream ends with an event pending.
    """

    def __init__(self) -> None:
        self._buffer = b''                # bytes of an incomplete trailing line
        self._event: str | None = None    # value of the last ``event:`` field
        self._data_lines: List[str] = []  # values of ``data:`` fields seen so far

    def feed(self, chunk: bytes) -> List[Dict[str, Any]]:
        """Consume a chunk of raw bytes and return the events it completes."""
        self._buffer += chunk
        *lines, self._buffer = self._buffer.split(b'\n')
        events = []
        for raw_line in lines:
            event = self._process_line(raw_line.decode('utf-8'))
            if event is not None:
                events.append(event)
        return events

    def close(self) -> List[Dict[str, Any]]:
        """Flush an unterminated last line and any pending event at end of stream."""
        events = []
        tail, self._buffer = self._buffer, b''
        if tail:
            event = self._process_line(tail.decode('utf-8'))
            if event is not None:
                events.append(event)
        event = self._dispatch()
        if event is not None:
            events.append(event)
        return events

    def _process_line(self, line: str) -> Dict[str, Any] | None:
        """Apply one decoded line; return an event when the line completes one."""
        line = line.rstrip('\r')  # tolerate CRLF line endings
        if not line:
            return self._dispatch()
        if line.startswith('event:'):
            self._event = line.removeprefix('event:').strip()
        elif line.startswith('data:'):
            self._data_lines.append(line.removeprefix('data:').strip())
        # Any other line (':' comments, id:, retry:) is ignored.
        return None

    def _dispatch(self) -> Dict[str, Any] | None:
        """Build the pending event and reset state; None if nothing is pending."""
        if self._event is None and not self._data_lines:
            # Blank line with no fields before it (e.g. consecutive blank lines).
            return None
        data_text = '\n'.join(self._data_lines)
        event = self._event
        self._event = None
        self._data_lines = []
        # Empty data maps to {}: handles local deployments where the API returns empty data.
        return {'event': event, 'data': json.loads(data_text) if data_text else {}}


class AsyncCozeAPIClient:
    """Minimal asyncio client for the Coze HTTP API (file upload, streaming chat).

    Use it as an async context manager so the underlying ``aiohttp`` session is
    opened on entry and closed on exit::

        async with AsyncCozeAPIClient(api_key) as client:
            file_id = await client.upload('picture.png')
    """

    def __init__(self, api_key: str, api_base: str = 'https://api.coze.cn'):
        self.api_key = api_key
        # Strip trailing slashes so f'{api_base}/v1/...' never produces '//'.
        self.api_base = api_base.rstrip('/')
        self.session: aiohttp.ClientSession | None = None

    async def __aenter__(self) -> 'AsyncCozeAPIClient':
        """Support ``async with``: make sure the HTTP session exists."""
        await self.coze_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session when leaving the ``async with`` block."""
        await self.close()

    async def coze_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it if it is missing or closed.

        The session's default headers carry the bearer-token ``Authorization``
        and ``Accept: text/event-stream``; its default timeout is 120 s total,
        30 s connect, 120 s per socket read.
        """
        if self.session is None or self.session.closed:
            connector = aiohttp.TCPConnector(
                ssl=False if self.api_base.startswith('http://') else True,
                limit=100,
                limit_per_host=30,
                keepalive_timeout=30,
                enable_cleanup_closed=True,
            )
            timeout = aiohttp.ClientTimeout(
                total=120,  # default timeout; individual requests override it
                connect=30,
                sock_read=120,
            )
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Accept': 'text/event-stream',
            }
            self.session = aiohttp.ClientSession(headers=headers, timeout=timeout, connector=connector)
        return self.session

    async def close(self):
        """Close the HTTP session if it is open; safe to call more than once."""
        session, self.session = self.session, None
        if session is not None and not session.closed:
            await session.close()

    @staticmethod
    async def _read_upload_source(file: Any) -> tuple[bytes, str]:
        """Normalize ``upload``'s ``file`` argument into ``(content, filename)``."""
        if isinstance(file, (str, Path)):
            path = Path(file)
            if not path.is_file():
                raise ValueError(f'File not found: {file}')
            # Read off the event loop so a large file doesn't stall other tasks.
            return await asyncio.to_thread(path.read_bytes), path.name

        if hasattr(file, 'read'):
            name = getattr(file, 'name', None)
            data = file.read()
        else:
            name = None
            data = file

        try:
            content = bytes(memoryview(data))
        except TypeError:
            raise TypeError(
                'Unsupported upload content: expected a path, a binary file object '
                f'or bytes-like data, got {type(data).__name__}'
            ) from None

        # Use the object's own name when it is a real path ('<stdin>'-style
        # pseudo names are skipped); otherwise fall back to the field name 'file'.
        if isinstance(name, str) and name and not (name.startswith('<') and name.endswith('>')):
            return content, Path(name).name
        return content, 'file'

    async def upload(self, file: Any) -> str:
        """Upload a file to ``/v1/files/upload`` and return its file id.

        Args:
            file: a ``pathlib.Path`` or ``str`` path, a binary file object
                (anything with ``read()`` returning bytes), or bytes-like data.

        Returns:
            The ``data.id`` field of the upload response.

        Raises:
            ValueError: if a path does not point to an existing file.
            TypeError: if ``file`` does not yield bytes-like content.
            CozeAPIError: on authentication failure, a non-200 status, an
                unparsable response, a non-zero ``code``, a missing file id,
                a timeout, or any other transport error.
        """
        content, filename = await self._read_upload_source(file)
        session = await self.coze_session()
        url = f'{self.api_base}/v1/files/upload'

        form = aiohttp.FormData()
        # An explicit filename keeps the extension visible to the server and to
        # aiohttp's content-type guess.
        form.add_field('file', io.BytesIO(content), filename=filename)

        try:
            async with session.post(
                url,
                data=form,
                # The session default asks for an event stream; this endpoint replies with JSON.
                headers={'Accept': 'application/json'},
                timeout=aiohttp.ClientTimeout(total=60),
            ) as response:
                if response.status == 401:
                    raise CozeAPIError('Coze API 认证失败,请检查 API Key 是否正确')
                response_text = await response.text()
                if response.status != 200:
                    raise CozeAPIError(f'文件上传失败,状态码: {response.status}, 响应: {response_text}')

                # Parse the text already read: response.json() raises
                # ContentTypeError (not JSONDecodeError) on a non-JSON Content-Type.
                try:
                    result = json.loads(response_text)
                except json.JSONDecodeError:
                    raise CozeAPIError(f'文件上传响应解析失败: {response_text}') from None
                if not isinstance(result, dict):
                    raise CozeAPIError(f'文件上传响应解析失败: {response_text}')
                if result.get('code') != 0:
                    raise CozeAPIError(f'文件上传失败: {result.get("msg", "未知错误")}')

                file_id = (result.get('data') or {}).get('id')
                if file_id is None:
                    raise CozeAPIError(f'文件上传响应缺少文件ID: {response_text}')
                return file_id
        except CozeAPIError:
            raise  # already descriptive; don't wrap it again
        except asyncio.TimeoutError as exc:
            raise CozeAPIError('文件上传超时') from exc
        except Exception as exc:
            raise CozeAPIError(f'文件上传失败: {exc}') from exc

    async def chat_messages(
        self,
        bot_id: str,
        user_id: str,
        additional_messages: List[Dict] | None = None,
        conversation_id: str | None = None,
        auto_save_history: bool = True,
        stream: bool = True,
        timeout: float = 120,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Send a chat request to ``/v3/chat`` and yield the response events.

        Args:
            bot_id: Bot ID.
            user_id: User ID.
            additional_messages: extra messages; sent only when non-empty.
            conversation_id: conversation to continue; sent as a query parameter.
            auto_save_history: whether the server should save history.
            stream: whether to request a streaming response.
            timeout: total timeout in seconds for the whole request/stream.

        Yields:
            ``{'event': str | None, 'data': Any}`` dicts, one per SSE event.
            ``data`` is the JSON-decoded ``data:`` payload, or ``{}`` when the
            payload is empty. If the server replies with a JSON document instead
            of an event stream, that document is yielded once with ``event`` None.

        Raises:
            CozeAPIError: on authentication failure, a non-200 status, a JSON
                reply with a non-zero ``code``, a timeout, or any other error
                while sending the request or reading the stream.
        """
        session = await self.coze_session()
        url = f'{self.api_base}/v3/chat'
        payload: Dict[str, Any] = {
            'bot_id': bot_id,
            'user_id': user_id,
            'stream': stream,
            'auto_save_history': auto_save_history,
        }
        if additional_messages:
            payload['additional_messages'] = additional_messages
        params = {'conversation_id': conversation_id} if conversation_id else {}

        try:
            async with session.post(
                url,
                json=payload,
                params=params,
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as response:
                if response.status == 401:
                    raise CozeAPIError('Coze API 认证失败,请检查 API Key 是否正确')
                if response.status != 200:
                    response_text = await response.text()
                    raise CozeAPIError(
                        f'Coze API 流式请求失败,状态码: {response.status}, 响应: {response_text}'
                    )

                if response.content_type == 'application/json':
                    # Not an event stream: a non-streaming reply, or (presumably)
                    # an error reported with HTTP 200 — surface it instead of
                    # silently yielding nothing.
                    result = json.loads(await response.text())
                    if isinstance(result, dict) and result.get('code', 0) != 0:
                        raise CozeAPIError(
                            f'Coze API 流式请求失败: {result.get("msg", "未知错误")}'
                        )
                    yield {'event': None, 'data': result}
                    return

                parser = _SSEParser()
                # iter_any() instead of line iteration: readline() refuses lines
                # longer than the stream's buffer limit, and a single data: line
                # can carry a whole message.
                async for chunk in response.content.iter_any():
                    for event in parser.feed(chunk):
                        yield event
                for event in parser.close():
                    yield event
        except CozeAPIError:
            raise  # already descriptive; don't wrap it again
        except asyncio.TimeoutError as exc:
            raise CozeAPIError(f'Coze API 流式请求超时 ({timeout}秒)') from exc
        except Exception as exc:
            raise CozeAPIError(f'Coze API 流式请求失败: {exc}') from exc