async def _req(
    self,
    args: dict,
) -> chat_completion.ChatCompletion:
    """Send a chat-completion request in streaming mode and fold it into one response.

    The request is always issued with ``stream=True``; the streamed chunks are
    then reassembled into a single ``ChatCompletion`` so callers can keep
    treating the result like a non-streaming response.

    Args:
        args: Keyword arguments forwarded to
            ``self.client.chat.completions.create``. Must contain ``"model"``.
            The dict itself is not modified.

    Returns:
        The reassembled ``ChatCompletion``, or ``None`` when the stream did not
        yield any chunk (callers are expected to handle that case).
    """
    # Work on a copy so the caller's dict is not silently switched to streaming.
    request_args = {**args, "stream": True}

    content_parts = []
    # Tool-call fragments grouped by their stream index; dict insertion order
    # keeps the calls in first-seen order.
    pending_tool_calls = {}

    last_chunk = None
    finish_reason = None
    logprobs = None
    usage = None

    resp_gen: openai.AsyncStream = await self.client.chat.completions.create(**request_args)

    async for chunk in resp_gen:
        if not chunk:
            continue

        last_chunk = chunk

        # Usage may arrive on a dedicated chunk, so remember the last one seen.
        if chunk.usage is not None:
            usage = chunk.usage

        # Some chunks carry no choices at all (e.g. usage-only chunks);
        # indexing choices[0] on them would raise IndexError.
        if not chunk.choices:
            continue

        choice = chunk.choices[0]
        logprobs = choice.logprobs
        if choice.finish_reason is not None:
            finish_reason = choice.finish_reason

        delta = choice.delta
        if delta is None:
            continue

        if delta.content is not None:
            content_parts.append(delta.content)

        for tc_delta in delta.tool_calls or ():
            slot = pending_tool_calls.setdefault(
                tc_delta.index,
                {"id": None, "name": None, "arguments": []},
            )
            # The first non-empty id/name wins; later fragments only fill gaps.
            slot["id"] = slot["id"] or tc_delta.id

            function_delta = tc_delta.function
            if function_delta is None:
                continue

            slot["name"] = slot["name"] or function_delta.name
            # Argument fragments can be None/empty; only real text is appended.
            if function_delta.arguments:
                slot["arguments"].append(function_delta.arguments)

    if last_chunk is None:
        return None

    real_tool_calls = [
        chat_completion_message_tool_call.ChatCompletionMessageToolCall(
            id=slot["id"],
            function=chat_completion_message_tool_call.Function(
                name=slot["name"],
                arguments="".join(slot["arguments"]),
            ),
            type="function",
        )
        for slot in pending_tool_calls.values()
    ]

    if finish_reason is None:
        # The completion-level Choice requires a finish_reason; fall back to a
        # value matching what was actually received when the stream gave none.
        finish_reason = "tool_calls" if real_tool_calls else "stop"

    return chat_completion.ChatCompletion(
        id=last_chunk.id,
        object="chat.completion",
        created=last_chunk.created,
        choices=[
            chat_completion.Choice(
                index=0,
                message=chat_completion.ChatCompletionMessage(
                    role="assistant",
                    content="".join(content_parts),
                    tool_calls=real_tool_calls or None,
                ),
                finish_reason=finish_reason,
                logprobs=logprobs,
            )
        ],
        model=args["model"],
        service_tier=last_chunk.service_tier,
        system_fingerprint=last_chunk.system_fingerprint,
        usage=usage,
    )