From 5ce32d2f040231bdea67408981424abd8ea8288f Mon Sep 17 00:00:00 2001 From: Dong_master <2213070223@qq.com> Date: Sat, 12 Jul 2025 18:09:24 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E4=BF=AE=E5=A4=8D=E4=BA=86=E5=9B=A0?= =?UTF-8?q?=E4=B8=BA=E8=BF=AD=E4=BB=A3=E6=95=B0=E6=8D=AE=E5=8F=AA=E6=8E=A8?= =?UTF-8?q?=E5=85=A5resq=5Fmessages=E5=92=8Cresq=5Fmessage=5Fchain?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=E7=BC=93=E5=AD=98=E5=88=B0=E5=86=85=E5=AD=98?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=95=B0=E6=8D=AE=E5=92=8C=E5=86=99=E5=85=A5?= =?UTF-8?q?log=E4=B8=AD=E7=9A=84=E6=95=B0=E6=8D=AE=E9=87=8F=E5=BA=9E?= =?UTF-8?q?=E5=A4=A7=EF=BC=8C=E4=BB=A5=E5=8F=8A=E5=B8=A6=E6=9C=89=E6=B7=B1?= =?UTF-8?q?=E5=BA=A6=E6=80=9D=E8=80=83=E6=A8=A1=E5=9E=8B=E7=9A=84think?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/provider/modelmgr/requester.py | 1 - pkg/provider/modelmgr/requesters/chatcmpl.py | 89 +++++++++----------- pkg/provider/runners/localagent.py | 1 - 3 files changed, 42 insertions(+), 49 deletions(-) diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 6b760616..fa4a9ff8 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -83,7 +83,6 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): model: RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: """调用API diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index b3ddea53..844aa83f 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -17,12 +17,15 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): """OpenAI ChatCompletion API 请求器""" client: openai.AsyncClient + is_content:bool default_config: dict[str, typing.Any] = { 'base_url': 'https://api.openai.com/v1', 'timeout': 120, } + + async def initialize(self): self.client = openai.AsyncClient( api_key='', @@ -30,6 +33,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): timeout=self.requester_cfg['timeout'], http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), ) + self.is_content = False async def _req( self, @@ -69,6 +73,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): async def _make_msg_chunk( self, + index:int, chat_completion: chat_completion.ChatCompletion, ) -> llm_entities.MessageChunk: @@ -83,7 +88,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): delta = chat_completion.delta.model_dump() if hasattr(chat_completion, 'delta') else {} # 确保 role 字段存在且不为 None - # print(delta) + # print(delta.values()) if 'role' not in delta or delta['role'] is None: delta['role'] = 'assistant' @@ -91,8 +96,17 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None # deepseek的reasoner模型 - if reasoning_content is not None: - delta['content'] = '\n' + reasoning_content + '\n\n' + delta['content'] + if reasoning_content is not None and index == 0: + delta['content'] += f'\n{reasoning_content}' + elif reasoning_content is None: + if self.is_content: + delta['content'] = delta['content'] + else: + delta['content'] = f'\n\n\n{delta["content"]}' + self.is_content = True + else: + delta['content'] += reasoning_content + message = llm_entities.MessageChunk(**delta) @@ -135,23 +149,17 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): if stream: current_content = '' args["stream"] = True + chunk_idx = 0 + self.is_content = False async for chunk in self._req_stream(args, extra_body=extra_args): - # print(chunk) - # 处理流式消息 - delta_message = await self._make_msg_chunk(chunk) + delta_message = await self._make_msg_chunk(chunk_idx,chunk) + # print(delta_message) if delta_message.content: current_content += delta_message.content delta_message.content = current_content - print(current_content) - delta_message.all_content = current_content - - # # 检查是否为最后一个块 - # if chunk.finish_reason is not None: - # delta_message.is_final = True - # - # yield delta_message - # 检查结束标志 + # delta_message.all_content = current_content + chunk_idx += 1 chunk_choices = getattr(chunk, 'choices', None) if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): delta_message.is_final = True @@ -215,9 +223,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + ) -> llm_entities.Message: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) @@ -231,26 +238,14 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): try: - if stream: - async for item in self._closure_stream( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - stream=stream, - extra_args=extra_args, - ): - return item - else: - print(req_messages) - msg = await self._closure( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - extra_args=extra_args, - ) - return msg + msg = await self._closure( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + extra_args=extra_args, + ) + return msg except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: @@ -316,16 +311,16 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): req_messages.append(msg_dict) try: - if stream: - async for item in self._closure_stream( - query=query, - req_messages=req_messages, - use_model=model, - use_funcs=funcs, - stream=stream, - extra_args=extra_args, - ): - yield item + async for item in self._closure_stream( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + stream=stream, + extra_args=extra_args, + ): + yield item + print(item) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 31c7e119..79de89a4 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -102,7 +102,6 @@ class LocalAgentRunner(runner.RequestRunner): query.use_llm_model, req_messages, query.use_funcs, - is_stream, extra_args=query.use_llm_model.model_entity.extra_args, ) yield msg