From e29691efbd7d6abc81036db20f47280cbde666d7 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Mon, 31 Jul 2023 11:59:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BD=BF=E7=94=A8tiktoken=E8=AE=A1?= =?UTF-8?q?=E7=AE=97tokens=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- main.py | 2 +- pkg/openai/api/chat_completion.py | 8 ++-- pkg/openai/modelmgr.py | 70 ++++++++++++++++++++++++++++++- pkg/openai/session.py | 44 +++++++------------ requirements.txt | 5 ++- 6 files changed, 93 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index a2f0ffbe..1bc93020 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,7 @@ cd QChatGPT 2. 安装依赖 ```bash -pip3 install requests yiri-mirai openai colorlog func_timeout dulwich Pillow nakuru-project-idk +pip3 install requests yiri-mirai openai colorlog func_timeout dulwich Pillow nakuru-project-idk CallingGPT tiktoken ``` 3. 运行一次主程序,生成配置文件 diff --git a/main.py b/main.py index 15d57010..45d97a56 100644 --- a/main.py +++ b/main.py @@ -47,7 +47,7 @@ def init_db(): def ensure_dependencies(): import pkg.utils.pkgmgr as pkgmgr - pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "--upgrade", + pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "tiktoken", "--upgrade", "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", "--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) diff --git a/pkg/openai/api/chat_completion.py b/pkg/openai/api/chat_completion.py index f9cab135..9b4b0c24 100644 --- a/pkg/openai/api/chat_completion.py +++ b/pkg/openai/api/chat_completion.py @@ -93,10 +93,10 @@ class ChatCompletionRequest(RequestBase): if 'function_call' in choice0['message']: self.pending_func_call = choice0['message']['function_call'] - self.append_message( - role="assistant", - content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False) - ) + # self.append_message( + # role="assistant", + # content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False) + # ) return { "id": resp["id"], diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 8e6c7f20..0b287752 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -7,6 +7,7 @@ Completion - text-davinci-003 等模型 """ import openai, logging, threading, asyncio import openai.error as aiE +import tiktoken from pkg.openai.api.model import RequestBase from pkg.openai.api.completion import CompletionRequest @@ -48,4 +49,71 @@ def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBa return ChatCompletionRequest(model_name, messages, **args) elif model_name in COMPLETION_MODELS: return CompletionRequest(model_name, messages, **args) - raise ValueError("不支持模型[{}],请检查配置文件".format(model_name)) \ No newline at end of file + raise ValueError("不支持模型[{}],请检查配置文件".format(model_name)) + + +def count_chat_completion_tokens(messages: list, model: str) -> int: + """Return the number of tokens used by a list of messages.""" + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + if model in { + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", + }: + tokens_per_message = 3 + tokens_per_name = 1 + elif model == "gpt-3.5-turbo-0301": + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_name = -1 # if there's a name, the role is omitted + elif "gpt-3.5-turbo" in model: + # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") + return count_chat_completion_tokens(messages, model="gpt-3.5-turbo-0613") + elif "gpt-4" in model: + # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") + return count_chat_completion_tokens(messages, model="gpt-4-0613") + else: + raise NotImplementedError( + f"""count_chat_completion_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" + ) + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> + return num_tokens + + +def count_completion_tokens(messages: list, model: str) -> int: + + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + text = "" + + for message in messages: + text += message['role'] + message['content'] + "\n" + + text += "assistant: " + + return len(encoding.encode(text)) + + +def count_tokens(messages: list, model: str): + if model in CHAT_COMPLETION_MODELS: + return count_chat_completion_tokens(messages, model) + elif model in COMPLETION_MODELS: + return count_completion_tokens(messages, model) + raise ValueError("不支持模型[{}],请检查配置文件".format(model)) diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 9c718f50..179d4c7b 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -16,6 +16,8 @@ import pkg.utils.context import pkg.plugin.host as plugin_host import pkg.plugin.models as plugin_models +from pkg.openai.modelmgr import count_tokens + # 运行时保存的所有session sessions = {} @@ -107,9 +109,6 @@ class Session: prompt = [] """使用list来保存会话中的回合""" - token_counts = [] - """每个回合的token数量""" - default_prompt = [] """本session的默认prompt""" @@ -215,12 +214,7 @@ class Session: config = pkg.utils.context.get_config() max_length = config.prompt_submit_length - prompts, counts = self.cut_out(text, max_length) - - # 计算请求前的prompt数量 - total_token_before_query = 0 - for token_count in counts: - total_token_before_query += token_count + prompts, _ = self.cut_out(text, max_length) res_text = "" @@ -281,8 +275,8 @@ class Session: self.prompt += pending_msgs # 向token_counts中添加本回合的token数量 - self.token_counts.append(total_tokens-total_token_before_query) - logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts)) + # self.token_counts.append(total_tokens-total_token_before_query) + # logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts)) if self.just_switched_to_exist_session: self.just_switched_to_exist_session = False @@ -319,24 +313,19 @@ class Session: # 包装目前的对话回合内容 changable_prompts = [] - changable_counts = [] - # 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1 - changable_index = len(self.prompt) - 1 - token_count_index = len(self.token_counts) - 1 - packed_tokens = 0 + use_model = pkg.utils.context.get_config().completion_api_params['model'] - while changable_index >= 0 and token_count_index >= 0: - if packed_tokens + self.token_counts[token_count_index] > max_tokens: + ptr = len(self.prompt) - 1 + + # 直接从后向前扫描拼接,不管是否是整回合 + while ptr >= 0: + if count_tokens(self.prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: break - changable_prompts.insert(0, self.prompt[changable_index]) - changable_prompts.insert(0, self.prompt[changable_index - 1]) - changable_counts.insert(0, self.token_counts[token_count_index]) - packed_tokens += self.token_counts[token_count_index] + changable_prompts.insert(0, self.prompt[ptr]) - changable_index -= 2 - token_count_index -= 1 + ptr -= 1 # 将default_prompt和changable_prompts合并 result_prompt = self.default_prompt + changable_prompts @@ -349,12 +338,9 @@ class Session: } ) - logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4), - packed_tokens, - changable_counts, - self.token_counts)) + logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4))) - return result_prompt, changable_counts + return result_prompt, count_tokens(changable_prompts, use_model) # 持久化session def persistence(self): diff --git a/requirements.txt b/requirements.txt index 60a072f8..9e2c30a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,11 @@ requests~=2.31.0 openai~=0.27.8 dulwich~=0.21.5 colorlog~=6.6.0 -yiri-mirai~=0.2.7 +yiri-mirai websockets urllib3~=1.26.10 func_timeout~=4.3.5 Pillow nakuru-project-idk -CallingGPT \ No newline at end of file +CallingGPT +tiktoken \ No newline at end of file