diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 93ed8f8c..1d7b92f8 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -8,7 +8,7 @@ from . import entities, operator, errors from ..config import manager as cfg_mgr # 引入所有算子以便注册 -from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update +from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama class CommandManager: diff --git a/pkg/command/operators/ollama.py b/pkg/command/operators/ollama.py new file mode 100644 index 00000000..db47918e --- /dev/null +++ b/pkg/command/operators/ollama.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import json +import typing + +import ollama +from .. import operator, entities + + +@operator.operator_class( + name="ollama", + help="ollama平台操作", + usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>" +) +class OllamaOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + content: str = '模型列表:\n' + model_list: list = ollama.list().get('models', []) + for model in model_list: + content += f"name: {model['name']}\n" + content += f"modified_at: {model['modified_at']}\n" + content += f"size: {bytes_to_mb(model['size'])}MB\n\n" + yield entities.CommandReturn(text=f"{content.strip()}") + + +def bytes_to_mb(num_bytes): + mb: float = num_bytes / 1024 / 1024 + return format(mb, '.2f') + + +@operator.operator_class( + name="show", + help="ollama模型详情", + privilege=2, + parent_class=OllamaOperator +) +class OllamaShowOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + content: str = '模型详情:\n' + try: + show: dict = ollama.show(model=context.crt_params[0]) + model_info: dict = show.get('model_info', {}) + ignore_show: str = 'too long to show...' + + for key in ['license', 'modelfile']: + show[key] = ignore_show + + for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']: + model_info[key] = ignore_show + + content += json.dumps(show, indent=4) + except ollama.ResponseError as e: + content = f"{e.error}" + + yield entities.CommandReturn(text=content.strip()) + + +@operator.operator_class( + name="pull", + help="ollama模型拉取", + privilege=2, + parent_class=OllamaOperator +) +class OllamaPullOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + model_list: list = ollama.list().get('models', []) + if context.crt_params[0] in [model['name'] for model in model_list]: + yield entities.CommandReturn(text="模型已存在") + return + + on_progress: bool = False + progress_count: int = 0 + try: + for resp in ollama.pull(model=context.crt_params[0], stream=True): + total: typing.Any = resp.get('total') + if not on_progress: + if total is not None: + on_progress = True + yield entities.CommandReturn(text=resp.get('status')) + else: + if total is None: + on_progress = False + + completed: typing.Any = resp.get('completed') + if isinstance(completed, int) and isinstance(total, int): + percentage_completed = (completed / total) * 100 + if percentage_completed > progress_count: + progress_count += 10 + yield entities.CommandReturn( + text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)") + except ollama.ResponseError as e: + yield entities.CommandReturn(text=f"拉取失败: {e.error}") + + +@operator.operator_class( + name="del", + help="ollama模型删除", + privilege=2, + parent_class=OllamaOperator +) +class OllamaDelOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + try: + ret: str = ollama.delete(model=context.crt_params[0])['status'] + except ollama.ResponseError as e: + ret = f"{e.error}" + yield entities.CommandReturn(text=ret) diff --git a/pkg/config/migrations/m010_ollama_requester_config.py b/pkg/config/migrations/m010_ollama_requester_config.py new file mode 100644 index 00000000..56e49663 --- /dev/null +++ b/pkg/config/migrations/m010_ollama_requester_config.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("ollama-requester-config", 10) +class MsgTruncatorConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return 'ollama-chat' not in self.ap.provider_cfg.data['requester'] + + async def run(self): + """执行迁移""" + + self.ap.provider_cfg.data['requester']['ollama-chat'] = { + "base-url": "http://127.0.0.1:11434", + "args": {}, + "timeout": 600 + } + + await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index 2ad1e974..862d90af 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -6,6 +6,7 @@ from .. import stage, app from ...config import migration from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg +from ...config.migrations import m010_ollama_requester_config @stage.stage_class("MigrationStage") diff --git a/pkg/provider/modelmgr/apis/ollamachat.py b/pkg/provider/modelmgr/apis/ollamachat.py new file mode 100644 index 00000000..88edfe7b --- /dev/null +++ b/pkg/provider/modelmgr/apis/ollamachat.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import asyncio +import os +import typing +from typing import Union, Mapping, Any, AsyncIterator + +import async_lru +import ollama + +from .. import api, entities, errors +from ... import entities as llm_entities +from ...tools import entities as tools_entities +from ....core import app +from ....utils import image + +REQUESTER_NAME: str = "ollama-chat" + + +@api.requester_class(REQUESTER_NAME) +class OllamaChatCompletions(api.LLMAPIRequester): + """Ollama平台 ChatCompletion API请求器""" + client: ollama.AsyncClient + request_cfg: dict + + def __init__(self, ap: app.Application): + super().__init__(ap) + self.ap = ap + self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME] + + async def initialize(self): + os.environ['OLLAMA_HOST'] = self.request_cfg['base-url'] + self.client = ollama.AsyncClient( + timeout=self.request_cfg['timeout'] + ) + + async def _req(self, + args: dict, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + return await self.client.chat( + **args + ) + + async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo, + user_funcs: list[tools_entities.LLMFunction] = None) -> ( + llm_entities.Message): + args: Any = self.request_cfg['args'].copy() + args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + + messages: list[dict] = req_messages.copy() + for msg in messages: + if 'content' in msg and isinstance(msg["content"], list): + text_content: list = [] + image_urls: list = [] + for me in msg["content"]: + if me["type"] == "text": + text_content.append(me["text"]) + elif me["type"] == "image_url": + image_url = await self.get_base64_str(me["image_url"]['url']) + image_urls.append(image_url) + msg["content"] = "\n".join(text_content) + msg["images"] = [url.split(',')[1] for url in image_urls] + args["messages"] = messages + + resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args) + message: llm_entities.Message = await self._make_msg(resp) + return message + + async def _make_msg( + self, + chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message: + message: Any = chat_completions.pop('message', None) + if message is None: + raise ValueError("chat_completions must contain a 'message' field") + + message.update(chat_completions) + ret_msg: llm_entities.Message = llm_entities.Message(**message) + return ret_msg + + async def call( + self, + model: entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + req_messages: list = [] + for m in messages: + msg_dict: dict = m.dict(exclude_none=True) + content: Any = msg_dict.get("content") + if isinstance(content, list): + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + msg_dict["content"] = "\n".join(part["text"] for part in content) + req_messages.append(msg_dict) + try: + return await self._closure(req_messages, model) + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + + @async_lru.alru_cache(maxsize=128) + async def get_base64_str( + self, + original_url: str, + ) -> str: + base64_image: str = await image.qq_image_url_to_base64(original_url) + return f"data:image/jpeg;base64,{base64_image}" diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 79e467a5..cf782302 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -6,7 +6,7 @@ from . import entities from ...core import app from . import token, api -from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl +from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" diff --git a/requirements.txt b/requirements.txt index 44bc285d..1b554d29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ pydantic websockets urllib3 psutil -async-lru \ No newline at end of file +async-lru +ollama \ No newline at end of file diff --git a/templates/provider.json b/templates/provider.json index 309fb827..32878fe9 100644 --- a/templates/provider.json +++ b/templates/provider.json @@ -37,6 +37,11 @@ "base-url": "https://api.deepseek.com", "args": {}, "timeout": 120 + }, + "ollama-chat": { + "base-url": "http://127.0.0.1:11434", + "args": {}, + "timeout": 600 } }, "model": "gpt-3.5-turbo",