From 917edb34130c59eedcc185e0497f2d747821f500 Mon Sep 17 00:00:00 2001 From: RockChinQ Date: Fri, 17 Apr 2026 21:54:24 +0800 Subject: [PATCH] fix(ollama): implement invoke_llm_stream for OllamaChatCompletions --- .../modelmgr/requesters/ollamachat.py | 127 ++++++++++++++++-- 1 file changed, 119 insertions(+), 8 deletions(-) diff --git a/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py b/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py index 97361f89..e89a65fa 100644 --- a/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/src/langbot/pkg/provider/modelmgr/requesters/ollamachat.py @@ -104,6 +104,21 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): return ret_msg + async def _prepare_messages( + self, + messages: typing.List[provider_message.Message], + ) -> list[dict]: + """Prepare messages for Ollama API request.""" + req_messages: list = [] + for m in messages: + msg_dict: dict = m.dict(exclude_none=True) + content: Any = msg_dict.get('content') + if isinstance(content, list): + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + msg_dict['content'] = '\n'.join(part['text'] for part in content) + req_messages.append(msg_dict) + return req_messages + async def invoke_llm( self, query: pipeline_query.Query, @@ -113,14 +128,7 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): extra_args: dict[str, typing.Any] = {}, remove_think: bool = False, ) -> provider_message.Message: - req_messages: list = [] - for m in messages: - msg_dict: dict = m.dict(exclude_none=True) - content: Any = msg_dict.get('content') - if isinstance(content, list): - if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): - msg_dict['content'] = '\n'.join(part['text'] for part in content) - req_messages.append(msg_dict) + req_messages = await self._prepare_messages(messages) try: return await self._closure( query=query, @@ -133,6 +141,109 @@ class OllamaChatCompletions(requester.ProviderAPIRequester): except asyncio.TimeoutError: raise errors.RequesterError('请求超时') + async def invoke_llm_stream( + self, + query: pipeline_query.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[provider_message.Message], + funcs: typing.List[resource_tool.LLMTool] = None, + extra_args: dict[str, typing.Any] = {}, + remove_think: bool = False, + ) -> provider_message.MessageChunk: + req_messages = await self._prepare_messages(messages) + + try: + args = extra_args.copy() + args['model'] = model.model_entity.name + + # Process messages for Ollama format + msgs: list[dict] = req_messages.copy() + for msg in msgs: + if 'content' in msg and isinstance(msg['content'], list): + text_content: list = [] + image_urls: list = [] + for me in msg['content']: + if me['type'] == 'text': + text_content.append(me['text']) + elif me['type'] == 'image_base64': + image_urls.append(me['image_base64']) + msg['content'] = '\n'.join(text_content) + msg['images'] = [url.split(',')[1] for url in image_urls] + if 'tool_calls' in msg: + for tool_call in msg['tool_calls']: + tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) + args['messages'] = msgs + + args['tools'] = [] + if funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs) + if tools: + args['tools'] = tools + + args['stream'] = True + + chunk_idx = 0 + thinking_started = False + thinking_ended = False + role = 'assistant' + + async for chunk in await self.client.chat(**args): + message: ollama.Message = chunk.message + done = chunk.done + + delta_content = message.content or '' + reasoning_content = getattr(message, 'thinking', '') or '' + + # Handle reasoning/thinking content + if reasoning_content: + if remove_think: + chunk_idx += 1 + continue + + if not thinking_started: + thinking_started = True + delta_content = '\n' + reasoning_content + else: + delta_content = reasoning_content + elif thinking_started and not thinking_ended and delta_content: + thinking_ended = True + delta_content = '\n\n' + delta_content + + # Handle tool calls + tool_calls_data = None + if message.tool_calls: + tool_calls_data = [] + for tc in message.tool_calls: + tool_calls_data.append( + { + 'id': uuid.uuid4().hex, + 'type': 'function', + 'function': { + 'name': tc.function.name, + 'arguments': json.dumps(tc.function.arguments), + }, + } + ) + + # Skip empty first chunk + if chunk_idx == 0 and not delta_content and not reasoning_content and not tool_calls_data: + chunk_idx += 1 + continue + + chunk_data = { + 'role': role, + 'content': delta_content if delta_content else None, + 'tool_calls': tool_calls_data, + 'is_final': bool(done), + } + chunk_data = {k: v for k, v in chunk_data.items() if v is not None} + + yield provider_message.MessageChunk(**chunk_data) + chunk_idx += 1 + + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + async def invoke_embedding( self, model: requester.RuntimeEmbeddingModel,