mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: 使用tiktoken计算tokens数
This commit is contained in:
@@ -242,7 +242,7 @@ cd QChatGPT
|
||||
2. 安装依赖
|
||||
|
||||
```bash
|
||||
pip3 install requests yiri-mirai openai colorlog func_timeout dulwich Pillow nakuru-project-idk
|
||||
pip3 install requests yiri-mirai openai colorlog func_timeout dulwich Pillow nakuru-project-idk CallingGPT tiktoken
|
||||
```
|
||||
|
||||
3. 运行一次主程序,生成配置文件
|
||||
|
||||
2
main.py
2
main.py
@@ -47,7 +47,7 @@ def init_db():
|
||||
|
||||
def ensure_dependencies():
|
||||
import pkg.utils.pkgmgr as pkgmgr
|
||||
pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "--upgrade",
|
||||
pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "tiktoken", "--upgrade",
|
||||
"-i", "https://pypi.tuna.tsinghua.edu.cn/simple",
|
||||
"--trusted-host", "pypi.tuna.tsinghua.edu.cn"])
|
||||
|
||||
|
||||
@@ -93,10 +93,10 @@ class ChatCompletionRequest(RequestBase):
|
||||
if 'function_call' in choice0['message']:
|
||||
self.pending_func_call = choice0['message']['function_call']
|
||||
|
||||
self.append_message(
|
||||
role="assistant",
|
||||
content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False)
|
||||
)
|
||||
# self.append_message(
|
||||
# role="assistant",
|
||||
# content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False)
|
||||
# )
|
||||
|
||||
return {
|
||||
"id": resp["id"],
|
||||
|
||||
@@ -7,6 +7,7 @@ Completion - text-davinci-003 等模型
|
||||
"""
|
||||
import openai, logging, threading, asyncio
|
||||
import openai.error as aiE
|
||||
import tiktoken
|
||||
|
||||
from pkg.openai.api.model import RequestBase
|
||||
from pkg.openai.api.completion import CompletionRequest
|
||||
@@ -49,3 +50,70 @@ def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBa
|
||||
elif model_name in COMPLETION_MODELS:
|
||||
return CompletionRequest(model_name, messages, **args)
|
||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
|
||||
|
||||
|
||||
def count_chat_completion_tokens(messages: list, model: str) -> int:
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model in {
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
}:
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif "gpt-3.5-turbo" in model:
|
||||
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
||||
return count_chat_completion_tokens(messages, model="gpt-3.5-turbo-0613")
|
||||
elif "gpt-4" in model:
|
||||
# print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return count_chat_completion_tokens(messages, model="gpt-4-0613")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"""count_chat_completion_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
|
||||
)
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
|
||||
|
||||
def count_completion_tokens(messages: list, model: str) -> int:
|
||||
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
text = ""
|
||||
|
||||
for message in messages:
|
||||
text += message['role'] + message['content'] + "\n"
|
||||
|
||||
text += "assistant: "
|
||||
|
||||
return len(encoding.encode(text))
|
||||
|
||||
|
||||
def count_tokens(messages: list, model: str):
|
||||
if model in CHAT_COMPLETION_MODELS:
|
||||
return count_chat_completion_tokens(messages, model)
|
||||
elif model in COMPLETION_MODELS:
|
||||
return count_completion_tokens(messages, model)
|
||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model))
|
||||
|
||||
@@ -16,6 +16,8 @@ import pkg.utils.context
|
||||
import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
|
||||
from pkg.openai.modelmgr import count_tokens
|
||||
|
||||
# 运行时保存的所有session
|
||||
sessions = {}
|
||||
|
||||
@@ -107,9 +109,6 @@ class Session:
|
||||
prompt = []
|
||||
"""使用list来保存会话中的回合"""
|
||||
|
||||
token_counts = []
|
||||
"""每个回合的token数量"""
|
||||
|
||||
default_prompt = []
|
||||
"""本session的默认prompt"""
|
||||
|
||||
@@ -215,12 +214,7 @@ class Session:
|
||||
config = pkg.utils.context.get_config()
|
||||
max_length = config.prompt_submit_length
|
||||
|
||||
prompts, counts = self.cut_out(text, max_length)
|
||||
|
||||
# 计算请求前的prompt数量
|
||||
total_token_before_query = 0
|
||||
for token_count in counts:
|
||||
total_token_before_query += token_count
|
||||
prompts, _ = self.cut_out(text, max_length)
|
||||
|
||||
res_text = ""
|
||||
|
||||
@@ -281,8 +275,8 @@ class Session:
|
||||
self.prompt += pending_msgs
|
||||
|
||||
# 向token_counts中添加本回合的token数量
|
||||
self.token_counts.append(total_tokens-total_token_before_query)
|
||||
logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
|
||||
# self.token_counts.append(total_tokens-total_token_before_query)
|
||||
# logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
|
||||
|
||||
if self.just_switched_to_exist_session:
|
||||
self.just_switched_to_exist_session = False
|
||||
@@ -319,24 +313,19 @@ class Session:
|
||||
|
||||
# 包装目前的对话回合内容
|
||||
changable_prompts = []
|
||||
changable_counts = []
|
||||
# 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1
|
||||
changable_index = len(self.prompt) - 1
|
||||
token_count_index = len(self.token_counts) - 1
|
||||
|
||||
packed_tokens = 0
|
||||
use_model = pkg.utils.context.get_config().completion_api_params['model']
|
||||
|
||||
while changable_index >= 0 and token_count_index >= 0:
|
||||
if packed_tokens + self.token_counts[token_count_index] > max_tokens:
|
||||
ptr = len(self.prompt) - 1
|
||||
|
||||
# 直接从后向前扫描拼接,不管是否是整回合
|
||||
while ptr >= 0:
|
||||
if count_tokens(self.prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
|
||||
break
|
||||
|
||||
changable_prompts.insert(0, self.prompt[changable_index])
|
||||
changable_prompts.insert(0, self.prompt[changable_index - 1])
|
||||
changable_counts.insert(0, self.token_counts[token_count_index])
|
||||
packed_tokens += self.token_counts[token_count_index]
|
||||
changable_prompts.insert(0, self.prompt[ptr])
|
||||
|
||||
changable_index -= 2
|
||||
token_count_index -= 1
|
||||
ptr -= 1
|
||||
|
||||
# 将default_prompt和changable_prompts合并
|
||||
result_prompt = self.default_prompt + changable_prompts
|
||||
@@ -349,12 +338,9 @@ class Session:
|
||||
}
|
||||
)
|
||||
|
||||
logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4),
|
||||
packed_tokens,
|
||||
changable_counts,
|
||||
self.token_counts))
|
||||
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
|
||||
|
||||
return result_prompt, changable_counts
|
||||
return result_prompt, count_tokens(changable_prompts, use_model)
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
|
||||
@@ -2,10 +2,11 @@ requests~=2.31.0
|
||||
openai~=0.27.8
|
||||
dulwich~=0.21.5
|
||||
colorlog~=6.6.0
|
||||
yiri-mirai~=0.2.7
|
||||
yiri-mirai
|
||||
websockets
|
||||
urllib3~=1.26.10
|
||||
func_timeout~=4.3.5
|
||||
Pillow
|
||||
nakuru-project-idk
|
||||
CallingGPT
|
||||
tiktoken
|
||||
Reference in New Issue
Block a user