feat: 使用tiktoken计算tokens数

This commit is contained in:
RockChinQ
2023-07-31 11:59:22 +08:00
parent 6d45327882
commit e29691efbd
6 changed files with 93 additions and 38 deletions
+1 -1
View File
@@ -242,7 +242,7 @@ cd QChatGPT
2. 安装依赖 2. 安装依赖
```bash ```bash
pip3 install requests yiri-mirai openai colorlog func_timeout dulwich Pillow nakuru-project-idk pip3 install requests yiri-mirai openai colorlog func_timeout dulwich Pillow nakuru-project-idk CallingGPT tiktoken
``` ```
3. 运行一次主程序,生成配置文件 3. 运行一次主程序,生成配置文件
+1 -1
View File
@@ -47,7 +47,7 @@ def init_db():
def ensure_dependencies(): def ensure_dependencies():
import pkg.utils.pkgmgr as pkgmgr import pkg.utils.pkgmgr as pkgmgr
pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "--upgrade", pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "tiktoken", "--upgrade",
"-i", "https://pypi.tuna.tsinghua.edu.cn/simple", "-i", "https://pypi.tuna.tsinghua.edu.cn/simple",
"--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) "--trusted-host", "pypi.tuna.tsinghua.edu.cn"])
+4 -4
View File
@@ -93,10 +93,10 @@ class ChatCompletionRequest(RequestBase):
if 'function_call' in choice0['message']: if 'function_call' in choice0['message']:
self.pending_func_call = choice0['message']['function_call'] self.pending_func_call = choice0['message']['function_call']
self.append_message( # self.append_message(
role="assistant", # role="assistant",
content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False) # content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False)
) # )
return { return {
"id": resp["id"], "id": resp["id"],
+69 -1
View File
@@ -7,6 +7,7 @@ Completion - text-davinci-003 等模型
""" """
import openai, logging, threading, asyncio import openai, logging, threading, asyncio
import openai.error as aiE import openai.error as aiE
import tiktoken
from pkg.openai.api.model import RequestBase from pkg.openai.api.model import RequestBase
from pkg.openai.api.completion import CompletionRequest from pkg.openai.api.completion import CompletionRequest
@@ -48,4 +49,71 @@ def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBa
return ChatCompletionRequest(model_name, messages, **args) return ChatCompletionRequest(model_name, messages, **args)
elif model_name in COMPLETION_MODELS: elif model_name in COMPLETION_MODELS:
return CompletionRequest(model_name, messages, **args) return CompletionRequest(model_name, messages, **args)
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name)) raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
def count_chat_completion_tokens(messages: list, model: str) -> int:
"""Return the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return count_chat_completion_tokens(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
# print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_chat_completion_tokens(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""count_chat_completion_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
def count_completion_tokens(messages: list, model: str) -> int:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
text = ""
for message in messages:
text += message['role'] + message['content'] + "\n"
text += "assistant: "
return len(encoding.encode(text))
def count_tokens(messages: list, model: str):
if model in CHAT_COMPLETION_MODELS:
return count_chat_completion_tokens(messages, model)
elif model in COMPLETION_MODELS:
return count_completion_tokens(messages, model)
raise ValueError("不支持模型[{}],请检查配置文件".format(model))
+15 -29
View File
@@ -16,6 +16,8 @@ import pkg.utils.context
import pkg.plugin.host as plugin_host import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models import pkg.plugin.models as plugin_models
from pkg.openai.modelmgr import count_tokens
# 运行时保存的所有session # 运行时保存的所有session
sessions = {} sessions = {}
@@ -107,9 +109,6 @@ class Session:
prompt = [] prompt = []
"""使用list来保存会话中的回合""" """使用list来保存会话中的回合"""
token_counts = []
"""每个回合的token数量"""
default_prompt = [] default_prompt = []
"""本session的默认prompt""" """本session的默认prompt"""
@@ -215,12 +214,7 @@ class Session:
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
max_length = config.prompt_submit_length max_length = config.prompt_submit_length
prompts, counts = self.cut_out(text, max_length) prompts, _ = self.cut_out(text, max_length)
# 计算请求前的prompt数量
total_token_before_query = 0
for token_count in counts:
total_token_before_query += token_count
res_text = "" res_text = ""
@@ -281,8 +275,8 @@ class Session:
self.prompt += pending_msgs self.prompt += pending_msgs
# 向token_counts中添加本回合的token数量 # 向token_counts中添加本回合的token数量
self.token_counts.append(total_tokens-total_token_before_query) # self.token_counts.append(total_tokens-total_token_before_query)
logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts)) # logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
if self.just_switched_to_exist_session: if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False self.just_switched_to_exist_session = False
@@ -319,24 +313,19 @@ class Session:
# 包装目前的对话回合内容 # 包装目前的对话回合内容
changable_prompts = [] changable_prompts = []
changable_counts = []
# 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1
changable_index = len(self.prompt) - 1
token_count_index = len(self.token_counts) - 1
packed_tokens = 0 use_model = pkg.utils.context.get_config().completion_api_params['model']
while changable_index >= 0 and token_count_index >= 0: ptr = len(self.prompt) - 1
if packed_tokens + self.token_counts[token_count_index] > max_tokens:
# 直接从后向前扫描拼接,不管是否是整回合
while ptr >= 0:
if count_tokens(self.prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
break break
changable_prompts.insert(0, self.prompt[changable_index]) changable_prompts.insert(0, self.prompt[ptr])
changable_prompts.insert(0, self.prompt[changable_index - 1])
changable_counts.insert(0, self.token_counts[token_count_index])
packed_tokens += self.token_counts[token_count_index]
changable_index -= 2 ptr -= 1
token_count_index -= 1
# 将default_prompt和changable_prompts合并 # 将default_prompt和changable_prompts合并
result_prompt = self.default_prompt + changable_prompts result_prompt = self.default_prompt + changable_prompts
@@ -349,12 +338,9 @@ class Session:
} }
) )
logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4), logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
packed_tokens,
changable_counts,
self.token_counts))
return result_prompt, changable_counts return result_prompt, count_tokens(changable_prompts, use_model)
# 持久化session # 持久化session
def persistence(self): def persistence(self):
+3 -2
View File
@@ -2,10 +2,11 @@ requests~=2.31.0
openai~=0.27.8 openai~=0.27.8
dulwich~=0.21.5 dulwich~=0.21.5
colorlog~=6.6.0 colorlog~=6.6.0
yiri-mirai~=0.2.7 yiri-mirai
websockets websockets
urllib3~=1.26.10 urllib3~=1.26.10
func_timeout~=4.3.5 func_timeout~=4.3.5
Pillow Pillow
nakuru-project-idk nakuru-project-idk
CallingGPT CallingGPT
tiktoken