diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 1d7b92f8..8d442fdb 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -8,7 +8,7 @@ from . import entities, operator, errors from ..config import manager as cfg_mgr # 引入所有算子以便注册 -from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama +from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model class CommandManager: diff --git a/pkg/command/operators/model.py b/pkg/command/operators/model.py new file mode 100644 index 00000000..692e2728 --- /dev/null +++ b/pkg/command/operators/model.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + +@operator.operator_class( + name="model", + help='显示和切换模型列表', + usage='!model\n!model show <模型名>\n!model set <模型名>', + privilege=2 +) +class ModelOperator(operator.CommandOperator): + """Model命令""" + + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + content = '模型列表:\n' + + model_list = self.ap.model_mgr.model_list + + for model in model_list: + content += f"\n名称: {model.name}\n" + content += f"请求器: {model.requester.name}\n" + + content += f"\n当前对话使用模型: {context.query.use_model.name}\n" + content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n" + + yield entities.CommandReturn(text=content.strip()) + + +@operator.operator_class( + name="show", + help='显示模型详情', + privilege=2, + parent_class=ModelOperator +) +class ModelShowOperator(operator.CommandOperator): + """Model Show命令""" + + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + model_name = context.crt_params[0] + + model = None + for _model in self.ap.model_mgr.model_list: + if model_name == _model.name: + model = _model + break + + if model is None: + yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}")) + else: + content = f"模型详情\n" + content += f"名称: {model.name}\n" + if model.model_name is not None: + content += f"请求模型名称: {model.model_name}\n" + content += f"请求器: {model.requester.name}\n" + content += f"密钥组: {model.token_mgr.provider}\n" + content += f"支持视觉: {model.vision_supported}\n" + content += f"支持工具: {model.tool_call_supported}\n" + + yield entities.CommandReturn(text=content.strip()) + +@operator.operator_class( + name="set", + help='设置默认使用模型', + privilege=2, + parent_class=ModelOperator +) +class ModelSetOperator(operator.CommandOperator): + """Model Set命令""" + + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + model_name = context.crt_params[0] + + model = None + for _model in self.ap.model_mgr.model_list: + if model_name == _model.name: + model = _model + break + + if model is None: + yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}")) + else: + self.ap.provider_cfg.data['model'] = model_name + await self.ap.provider_cfg.dump_config() + yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效") diff --git a/pkg/command/operators/ollama.py b/pkg/command/operators/ollama.py index db47918e..f5ed382d 100644 --- a/pkg/command/operators/ollama.py +++ b/pkg/command/operators/ollama.py @@ -2,9 +2,10 @@ from __future__ import annotations import json import typing +import traceback import ollama -from .. import operator, entities +from .. import operator, entities, errors @operator.operator_class( @@ -16,13 +17,16 @@ class OllamaOperator(operator.CommandOperator): async def execute( self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - content: str = '模型列表:\n' - model_list: list = ollama.list().get('models', []) - for model in model_list: - content += f"name: {model['name']}\n" - content += f"modified_at: {model['modified_at']}\n" - content += f"size: {bytes_to_mb(model['size'])}MB\n\n" - yield entities.CommandReturn(text=f"{content.strip()}") + try: + content: str = '模型列表:\n' + model_list: list = ollama.list().get('models', []) + for model in model_list: + content += f"名称: {model['name']}\n" + content += f"修改时间: {model['modified_at']}\n" + content += f"大小: {bytes_to_mb(model['size'])}MB\n\n" + yield entities.CommandReturn(text=f"{content.strip()}") + except ollama.ResponseError as e: + yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常")) def bytes_to_mb(num_bytes): @@ -53,11 +57,9 @@ class OllamaShowOperator(operator.CommandOperator): model_info[key] = ignore_show content += json.dumps(show, indent=4) + yield entities.CommandReturn(text=content.strip()) except ollama.ResponseError as e: - content = f"{e.error}" - - yield entities.CommandReturn(text=content.strip()) - + yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常")) @operator.operator_class( name="pull", @@ -69,9 +71,13 @@ class OllamaPullOperator(operator.CommandOperator): async def execute( self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - model_list: list = ollama.list().get('models', []) - if context.crt_params[0] in [model['name'] for model in model_list]: - yield entities.CommandReturn(text="模型已存在") + try: + model_list: list = ollama.list().get('models', []) + if context.crt_params[0] in [model['name'] for model in model_list]: + yield entities.CommandReturn(text="模型已存在") + return + except ollama.ResponseError as e: + yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常")) return on_progress: bool = False