diff --git a/pkg/provider/requester/apis/chatcmpl.py b/pkg/provider/requester/apis/chatcmpl.py index a565009e..52b895cc 100644 --- a/pkg/provider/requester/apis/chatcmpl.py +++ b/pkg/provider/requester/apis/chatcmpl.py @@ -47,7 +47,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester): self.client.api_key = conversation.use_model.token_mgr.get_token() args = self.ap.cfg_mgr.data["completion_api_params"].copy() - args["model"] = conversation.use_model.name + args["model"] = conversation.use_model.name if conversation.use_model.model_name is None else conversation.use_model.model_name if conversation.use_model.tool_call_supported: tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation) diff --git a/pkg/provider/requester/entities.py b/pkg/provider/requester/entities.py index c003564f..d4c51d6f 100644 --- a/pkg/provider/requester/entities.py +++ b/pkg/provider/requester/entities.py @@ -13,7 +13,7 @@ class LLMModelInfo(pydantic.BaseModel): name: str - provider: str + model_name: typing.Optional[str] = None token_mgr: token.TokenManager @@ -23,5 +23,7 @@ class LLMModelInfo(pydantic.BaseModel): tool_call_supported: typing.Optional[bool] = False + max_tokens: typing.Optional[int] = 2048 + class Config: arbitrary_types_allowed = True diff --git a/pkg/provider/requester/modelmgr.py b/pkg/provider/requester/modelmgr.py index 40bf313e..6510c1de 100644 --- a/pkg/provider/requester/modelmgr.py +++ b/pkg/provider/requester/modelmgr.py @@ -33,13 +33,183 @@ class ModelManager: tiktoken_tokenizer = tiktoken.Tiktoken(self.ap) - self.model_list.append( + model_list = [ entities.LLMModelInfo( name="gpt-3.5-turbo", - provider="openai", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer, + max_tokens=4096 + ), + entities.LLMModelInfo( + name="gpt-3.5-turbo-1106", token_mgr=openai_token_mgr, requester=openai_chat_completion, tool_call_supported=True, tokenizer=tiktoken_tokenizer - ) - ) + ), + entities.LLMModelInfo( + name="gpt-3.5-turbo-16k", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-3.5-turbo-0613", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-3.5-turbo-16k-0613", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-3.5-turbo-0301", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-4-1106-preview", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-4-vision-preview", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-4", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-4-32k", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-4-0613", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-4-32k-0613", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-4-0314", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="gpt-4-32k-0314", + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=True, + tokenizer=tiktoken_tokenizer + ), + ] + + self.model_list.extend(model_list) + + one_api_model_list = [ + entities.LLMModelInfo( + name="OneAPI/SparkDesk", + model_name='SparkDesk', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="OneAPI/chatglm_pro", + model_name='chatglm_pro', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="OneAPI/chatglm_std", + model_name='chatglm_std', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="OneAPI/chatglm_lite", + model_name='chatglm_lite', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="OneAPI/qwen-v1", + model_name='qwen-v1', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="OneAPI/qwen-plus-v1", + model_name='qwen-plus-v1', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="OneAPI/ERNIE-Bot", + model_name='ERNIE-Bot', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="OneAPI/ERNIE-Bot-turbo", + model_name='ERNIE-Bot-turbo', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + entities.LLMModelInfo( + name="OneAPI/gemini-pro", + model_name='gemini-pro', + token_mgr=openai_token_mgr, + requester=openai_chat_completion, + tool_call_supported=False, + tokenizer=tiktoken_tokenizer + ), + ] + + self.model_list.extend(one_api_model_list)