From 68cdd163d30f02e03fdade10feb6006bf7128d31 Mon Sep 17 00:00:00 2001 From: Dong_master <2213070223@qq.com> Date: Fri, 4 Jul 2025 03:26:44 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=81=E5=BC=8F=E5=9F=BA=E6=9C=AC=E6=B5=81?= =?UTF-8?q?=E7=A8=8B=E5=B7=B2=E9=80=9A=E8=BF=87=E4=BF=AE=E6=94=B9=E4=BA=86?= =?UTF-8?q?yield=E5=92=8Creturn=E7=9A=84=E5=86=B2=E7=AA=81=E5=AF=BC?= =?UTF-8?q?=E8=87=B4=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/process/handlers/chat.py | 50 ++++-- pkg/pipeline/respback/respback.py | 9 +- pkg/platform/adapter.py | 4 + pkg/platform/sources/lark.py | 135 ++++++++++----- pkg/provider/entities.py | 4 +- pkg/provider/modelmgr/requester.py | 27 ++- pkg/provider/modelmgr/requesters/chatcmpl.py | 167 +++++++++++++++---- pkg/provider/runners/localagent.py | 44 ++--- 8 files changed, 323 insertions(+), 117 deletions(-) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 9b3e0cd5..3a5925cc 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -59,8 +59,11 @@ class ChatMessageHandler(handler.MessageHandler): query.user_message.content = event_ctx.event.alter text_length = 0 - - is_stream = query.adapter.is_stream_output_supported() + try: + is_stream = query.adapter.is_stream + except AttributeError: + is_stream = False + print(is_stream) try: for r in runner_module.preregistered_runners: @@ -70,31 +73,44 @@ class ChatMessageHandler(handler.MessageHandler): else: raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}') if is_stream: - async for results in runner.run(query): - async for result in results: + # async for results in runner.run(query): + async for result in runner.run(query): + print(result) + query.resp_messages.append(result) + print(result) - query.resp_messages.append(result) + self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') - self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') + if result.content is not None: + text_length += len(result.content) - if result.content is not None: - text_length += len(result.content) - - # current_chain = platform_message.MessageChain([]) - # for msg in accumulated_messages: - # if msg.content is not None: - # current_chain.append(platform_message.Plain(msg.content)) - # query.resp_message_chain = [current_chain] - - yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + # for result in results: + # + # query.resp_messages.append(result) + # print(result) + # + # self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.content)}') + # + # if result.content is not None: + # text_length += len(result.content) + # + # # current_chain = platform_message.MessageChain([]) + # # for msg in accumulated_messages: + # # if msg.content is not None: + # # current_chain.append(platform_message.Plain(msg.content)) + # # query.resp_message_chain = [current_chain] + # + # yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) # query.resp_messages.append(results) # self.ap.logger.info(f'对话({query.query_id})响应') # yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: - + print("非流式") async for result in runner.run(query): query.resp_messages.append(result) + print(result) self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 4ac4e1e3..52714ce2 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -3,6 +3,7 @@ from __future__ import annotations import random import asyncio +from typing_inspection.typing_objects import is_final from ...platform.types import events as platform_events from ...platform.types import message as platform_message @@ -39,12 +40,16 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin = query.pipeline_config['output']['misc']['quote-origin'] has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages) + print(has_chunks) if has_chunks and hasattr(query.adapter,'reply_message_chunk'): + is_final = [msg.is_final for msg in query.resp_messages][0] + print(is_final) await query.adapter.reply_message_chunk( message_source=query.message_event, - message_id=query.query_id, - message_generator=query.resp_message_chain[-1], + message_id=query.message_event.message_chain.message_id, + message=query.resp_message_chain[-1], quote_origin=quote_origin, + is_final=is_final, ) else: await query.adapter.reply_message( diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 18403b75..3951326c 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -25,6 +25,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): logger: EventLogger + is_stream: bool + def __init__(self, config: dict, ap: app.Application, logger: EventLogger): """初始化适配器 @@ -67,6 +69,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): message_id: int, message: platform_message.MessageChain, quote_origin: bool = False, + is_final: bool = False, ): """回复消息(流式输出) Args: @@ -114,6 +117,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): async def is_stream_output_supported(self) -> bool: """是否支持流式输出""" + self.is_stream = False return False async def kill(self) -> bool: diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index 9c9b5605..af57d66c 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -18,6 +18,7 @@ import aiohttp import lark_oapi.ws.exception import quart from lark_oapi.api.im.v1 import * +from lark_oapi.api.cardkit.v1 import * from .. import adapter from ...core import app @@ -348,6 +349,8 @@ class LarkAdapter(adapter.MessagePlatformAdapter): card_id_dict: dict[str, str] + seq: int + def __init__(self, config: dict, ap: app.Application, logger: EventLogger): self.config = config self.ap = ap @@ -356,6 +359,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): self.listeners = {} self.message_id_to_card_id = {} self.card_id_dict = {} + self.seq = 0 @self.quart_app.route('/lark/callback', methods=['POST']) async def lark_callback(): @@ -401,54 +405,79 @@ class LarkAdapter(adapter.MessagePlatformAdapter): return {'code': 500, 'message': 'error'} - def is_stream_output_supported() -> bool: + async def is_stream_output_supported() -> bool: is_stream = False - if self.config.get("",None): + if self.config.get("enable-card-reply",None): is_stream = True + self.is_stream = is_stream return is_stream - async def create_card_id(): + async def create_card_id(message_id): try: - is_stream = is_stream_output_supported() + is_stream = await is_stream_output_supported() if is_stream: self.ap.logger.debug('飞书支持stream输出,创建卡片......') - card_id = '' - if self.card_id_dict: - card_id = [k for k,v in self.card_id_dict.items() if (v+datetime.timedelta(days=14))< datetime.datetime.now()][0] + # card_id = '' + # # if self.card_id_dict: + # # card_id = [k for k,v in self.card_id_dict.items() if (v+datetime.timedelta(days=14))< datetime.datetime.now()][0] + # + # if self.card_id_dict is None: + # # content = { + # # "type": "card_json", + # # "data": {"schema":"2.0","header":{"title":{"content":"bot","tag":"plain_text"}},"body":{"elements":[{"tag":"markdown","content":""}]}} + # # } + # card_data = {"schema":"2.0","header":{"title":{"content":"bot","tag":"plain_text"}}, + # "body":{"elements":[{"tag":"markdown","content":""}]},"config": {"streaming_mode": True, + # "streaming_config": {"print_strategy": "fast"}}} + # + # request: CreateCardRequest = CreateCardRequest.builder() \ + # .request_body( + # CreateCardRequestBody.builder() + # .type("card_json") + # .data(json.dumps(card_data)) \ + # .build() + # ).build() + # + # # 发起请求 + # response: CreateCardResponse = self.api_client.cardkit.v1.card.create(request) + # + # + # # 处理失败返回 + # if not response.success(): + # raise Exception( + # f"client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}") + # + # self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}') + # self.card_id_dict[response.data.card_id] = datetime.datetime.now() + # + # card_id = response.data.card_id + card_data = {"schema": "2.0", "header": {"title": {"content": "bot", "tag": "plain_text"}}, + "body": {"elements": [{"tag": "markdown", "content": "[思考中.....]","element_id":"markdown_1"}]}, + "config": {"streaming_mode": True, + "streaming_config": {"print_strategy": "fast"}}} - if self.card_id_dict is None or card_id == '': - # content = { - # "type": "card_json", - # "data": {"schema":"2.0","header":{"title":{"content":"bot","tag":"plain_text"}},"body":{"elements":[{"tag":"markdown","content":""}]}} - # } - card_data = {"schema":"2.0","header":{"title":{"content":"bot","tag":"plain_text"}}, - "body":{"elements":[{"tag":"markdown","content":""}]},"config": {"streaming_mode": True, - "streaming_config": {"print_strategy": "fast"}}} + request: CreateCardRequest = CreateCardRequest.builder() \ + .request_body( + CreateCardRequestBody.builder() + .type("card_json") + .data(json.dumps(card_data)) \ + .build() + ).build() - request: CreateCardRequest = ( - CreateCardRequest.builder() - .request_body( - CreateCardRequestBody.builder() - .type("card_json") - .data(json.dumps(card_data)) - .build() - ) - ) - # 发起请求 - response: CreateCardResponse = await self.api_client.im.v1.card.create(request) + # 发起请求 + response: CreateCardResponse = self.api_client.cardkit.v1.card.create(request) + # 处理失败返回 + if not response.success(): + raise Exception( + f"client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}") - # 处理失败返回 - if not response.success(): - raise Exception( - f"client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}") + self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}') + self.card_id_dict[message_id] = response.data.card_id - self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}') - self.card_id_dict[response.data.card_id] = datetime.datetime.now() - - card_id = response.data.card_id + card_id = response.data.card_id return card_id except Exception as e: @@ -458,10 +487,10 @@ class LarkAdapter(adapter.MessagePlatformAdapter): async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1): - if is_stream_output_supported(): + if await is_stream_output_supported(): self.ap.logger.debug('卡片回复模式开启') # 开启卡片回复模式. 这里可以实现飞书一发消息,马上创建卡片进行回复"思考中..." - card_id = await create_card_id() + card_id = await create_card_id(event.event.message.message_id) reply_message_id = await self.create_message_card(card_id, event.event.message.message_id) self.message_id_to_card_id[event.event.message.message_id] = (reply_message_id, time.time()) @@ -500,8 +529,8 @@ class LarkAdapter(adapter.MessagePlatformAdapter): # TODO 目前只支持卡片模板方式,且卡片变量一定是content,未来这块要做成可配置 # 发消息马上就会回复显示初始化的content信息,即思考中 content = { - 'type': 'template', - 'data': {'template_id': card_id, 'template_variable': {'content': 'Thinking...'}}, + 'type': 'card', + 'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}}, } request: ReplyMessageRequest = ( ReplyMessageRequest.builder() @@ -564,35 +593,49 @@ class LarkAdapter(adapter.MessagePlatformAdapter): async def reply_message_chunk( self, message_source: platform_events.MessageEvent, + message_id: str, message: platform_message.MessageChain, quote_origin: bool = False, + is_final: bool = False, ): """ 回复消息变成更新卡片消息 """ lark_message = await self.message_converter.yiri2target(message, self.api_client) + if not is_final: + self.seq += 1 + + + text_message = '' for ele in lark_message[0]: if ele['tag'] == 'text': text_message += ele['text'] elif ele['tag'] == 'md': text_message += ele['text'] + print(text_message) content = { - 'type': 'template', - 'data': {'template_id': self.config['card_template_id'], 'template_variable': {'content': text_message}}, + 'type': 'card_json', + 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}}, } - request: PatchMessageRequest = ( - PatchMessageRequest.builder() - .message_id(self.message_id_to_card_id[message_source.message_chain.message_id][0]) - .request_body(PatchMessageRequestBody.builder().content(json.dumps(content)).build()) + request: ContentCardElementRequest = ContentCardElementRequest.builder() \ + .card_id(self.card_id_dict[message_id]) \ + .element_id("markdown_1") \ + .request_body(ContentCardElementRequestBody.builder() + # .uuid("a0d69e20-1dd1-458b-k525-dfeca4015204") + .content(text_message) + .sequence(self.seq) + .build()) \ .build() - ) + if is_final: + self.seq = 0 # 发起请求 - response: PatchMessageResponse = self.api_client.im.v1.message.patch(request) + response: ContentCardElementResponse = self.api_client.cardkit.v1.card_element.content(request) + # 处理失败返回 if not response.success(): diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index a149fea3..e8037e68 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -140,12 +140,12 @@ class MessageChunk(pydantic.BaseModel): content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None """内容""" - # tool_calls: typing.Optional[list[ToolCall]] = None + tool_calls: typing.Optional[list[ToolCall]] = None """工具调用""" tool_call_id: typing.Optional[str] = None - tool_calls: typing.Optional[list[ToolCallChunk]] = None + # tool_calls: typing.Optional[list[ToolCallChunk]] = None is_final: bool = False diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 3e5e791f..49a28f56 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -62,7 +62,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): funcs: typing.List[tools_entities.LLMFunction] = None, stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + ) -> llm_entities.Message: """调用API Args: @@ -72,6 +72,29 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: - llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: 返回消息对象 + llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 + """ + pass + + @abc.abstractmethod + async def invoke_llm_stream( + self, + query: core_entities.Query, + model: RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + stream: bool = False, + extra_args: dict[str, typing.Any] = {}, + ) -> llm_entities.MessageChunk: + """调用API + + Args: + model (RuntimeLLMModel): 使用的模型信息 + messages (typing.List[llm_entities.Message]): 消息对象列表 + funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. + extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. + + Returns: + llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 """ pass diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 22931611..f06041fc 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -38,6 +38,15 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): ) -> chat_completion.ChatCompletion: return await self.client.chat.completions.create(**args, extra_body=extra_body) + async def _req_stream( + self, + args: dict, + extra_body: dict = {}, + ) -> chat_completion.ChatCompletion: + + async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): + yield chunk + async def _make_msg( self, chat_completion: chat_completion.ChatCompletion, @@ -62,9 +71,19 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): self, chat_completion: chat_completion.ChatCompletion, ) -> llm_entities.MessageChunk: - choice = chat_completion.choices[0] - delta = choice.delta.model_dump() + + # 处理流式chunk和完整响应的差异 + # print(chat_completion.choices[0]) + if hasattr(chat_completion, 'choices'): + # 完整响应模式 + choice = chat_completion.choices[0] + delta = choice.delta.model_dump() if hasattr(choice, 'delta') else choice.message.model_dump() + else: + # 流式chunk模式 + delta = chat_completion.delta.model_dump() if hasattr(chat_completion, 'delta') else {} + # 确保 role 字段存在且不为 None + # print(delta) if 'role' not in delta or delta['role'] is None: delta['role'] = 'assistant' @@ -78,8 +97,8 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): message = llm_entities.MessageChunk(**delta) return message - - async def _closure( + + async def _closure_stream( self, query: core_entities.Query, req_messages: list[dict], @@ -87,7 +106,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): use_funcs: list[tools_entities.LLMFunction] = None, stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.Message: + ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() args = {} @@ -115,36 +134,76 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): if stream: current_content = '' - async for chunk in await self._req(args, extra_body=extra_args): + args["stream"] = True + async for chunk in self._req_stream(args, extra_body=extra_args): + # print(chunk) # 处理流式消息 - delta_message = await self._make_msg_chunk( - chat_completion=chunk, - ) + delta_message = await self._make_msg_chunk(chunk) if delta_message.content: current_content += delta_message.content + delta_message.content = current_content + print(current_content) delta_message.all_content = current_content - - # 检查是否为最后一个块 - if chunk.choices[0].finish_reason is not None: + + # # 检查是否为最后一个块 + # if chunk.finish_reason is not None: + # delta_message.is_final = True + # + # yield delta_message + # 检查结束标志 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): delta_message.is_final = True - yield delta_message - return - - else: + yield delta_message + # return - # 非流式请求 - resp = await self._req(args, extra_body=extra_args) - # 处理请求结果 - # 发送请求 - resp = await self._req(args, extra_body=extra_args) + + async def _closure( + self, + query: core_entities.Query, + req_messages: list[dict], + use_model: requester.RuntimeLLMModel, + use_funcs: list[tools_entities.LLMFunction] = None, + stream: bool = False, + extra_args: dict[str, typing.Any] = {}, + ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + self.client.api_key = use_model.token_mgr.get_token() - # 处理请求结果 - message = await self._make_msg(resp) + args = {} + args['model'] = use_model.model_entity.name - return message - + if use_funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) + + if tools: + args['tools'] = tools + + # 设置此次请求中的messages + messages = req_messages.copy() + + # 检查vision + for msg in messages: + if 'content' in msg and isinstance(msg['content'], list): + for me in msg['content']: + if me['type'] == 'image_base64': + me['image_url'] = {'url': me['image_base64']} + me['type'] = 'image_url' + del me['image_base64'] + + args['messages'] = messages + + + + # 发送请求 + + resp = await self._req(args, extra_body=extra_args) + # 处理请求结果 + message = await self._make_msg(resp) + + + return message @@ -171,8 +230,9 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): req_messages.append(msg_dict) try: + if stream: - async for item in self._closure( + async for item in self._closure_stream( query=query, req_messages=req_messages, use_model=model, @@ -180,16 +240,17 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): stream=stream, extra_args=extra_args, ): - yield item - return + return item else: - return await self._closure( + print(req_messages) + msg = await self._closure( query=query, req_messages=req_messages, use_model=model, use_funcs=funcs, extra_args=extra_args, ) + return msg except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: @@ -205,3 +266,51 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + + async def invoke_llm_stream( + self, + query: core_entities.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + stream: bool = False, + extra_args: dict[str, typing.Any] = {}, + ) -> llm_entities.MessageChunk: + req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 + for m in messages: + msg_dict = m.dict(exclude_none=True) + content = msg_dict.get('content') + if isinstance(content, list): + # 检查 content 列表中是否每个部分都是文本 + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + # 将所有文本部分合并为一个字符串 + msg_dict['content'] = '\n'.join(part['text'] for part in content) + req_messages.append(msg_dict) + + try: + if stream: + async for item in self._closure_stream( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + stream=stream, + extra_args=extra_args, + ): + yield item + + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + except openai.BadRequestError as e: + if 'context_length_exceeded' in e.message: + raise errors.RequesterError(f'上文过长,请重置会话: {e.message}') + else: + raise errors.RequesterError(f'请求参数错误: {e.message}') + except openai.AuthenticationError as e: + raise errors.RequesterError(f'无效的 api-key: {e.message}') + except openai.NotFoundError as e: + raise errors.RequesterError(f'请求路径错误: {e.message}') + except openai.RateLimitError as e: + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') + except openai.APIError as e: + raise errors.RequesterError(f'请求错误: {e.message}') \ No newline at end of file diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 02b2db16..da97e334 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -24,25 +24,30 @@ class LocalAgentRunner(runner.RequestRunner): pending_tool_calls = [] req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] - - is_stream = query.adapter.is_stream_output_supported() + try: + is_stream = query.adapter.is_stream + except AttributeError: + is_stream = False # while True: # pass if not is_stream: # 非流式输出,直接请求 + # print(123) msg = await query.use_llm_model.requester.invoke_llm( query, query.use_llm_model, req_messages, query.use_funcs, + is_stream, extra_args=query.use_llm_model.model_entity.extra_args, ) yield msg final_msg = msg + print(final_msg) else: # 流式输出,需要处理工具调用 tool_calls_map: dict[str, llm_entities.ToolCall] = {} - async for msg in await query.use_llm_model.requester.invoke_llm( + async for msg in query.use_llm_model.requester.invoke_llm_stream( query, query.use_llm_model, req_messages, @@ -51,20 +56,20 @@ class LocalAgentRunner(runner.RequestRunner): extra_args=query.use_llm_model.model_entity.extra_args, ): yield msg - if msg.tool_calls: - for tool_call in msg.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + # if msg.tool_calls: + # for tool_call in msg.tool_calls: + # if tool_call.id not in tool_calls_map: + # tool_calls_map[tool_call.id] = llm_entities.ToolCall( + # id=tool_call.id, + # type=tool_call.type, + # function=llm_entities.FunctionCall( + # name=tool_call.function.name if tool_call.function else '', + # arguments='' + # ), + # ) + # if tool_call.function and tool_call.function.arguments: + # # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + # tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments final_msg = llm_entities.Message( role=msg.role, content=msg.all_content, @@ -105,7 +110,7 @@ class LocalAgentRunner(runner.RequestRunner): if is_stream: tool_calls_map = {} - async for msg in await query.use_llm_model.requester.invoke_llm( + async for msg in await query.use_llm_model.requester.invoke_llm_stream( query, query.use_llm_model, req_messages, @@ -130,10 +135,11 @@ class LocalAgentRunner(runner.RequestRunner): tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments final_msg = llm_entities.Message( role=msg.role, - content=all_content, + content=msg.all_content, tool_calls=list(tool_calls_map.values()), ) else: + print("非流式") # 处理完所有调用,再次请求 msg = await query.use_llm_model.requester.invoke_llm( query,