mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 07:54:19 +00:00
feat: 使用tiktoken计算tokens数
This commit is contained in:
@@ -242,7 +242,7 @@ cd QChatGPT
|
|||||||
2. 安装依赖
|
2. 安装依赖
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install requests yiri-mirai openai colorlog func_timeout dulwich Pillow nakuru-project-idk
|
pip3 install requests yiri-mirai openai colorlog func_timeout dulwich Pillow nakuru-project-idk CallingGPT tiktoken
|
||||||
```
|
```
|
||||||
|
|
||||||
3. 运行一次主程序,生成配置文件
|
3. 运行一次主程序,生成配置文件
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def init_db():
|
|||||||
|
|
||||||
def ensure_dependencies():
|
def ensure_dependencies():
|
||||||
import pkg.utils.pkgmgr as pkgmgr
|
import pkg.utils.pkgmgr as pkgmgr
|
||||||
pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "--upgrade",
|
pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "tiktoken", "--upgrade",
|
||||||
"-i", "https://pypi.tuna.tsinghua.edu.cn/simple",
|
"-i", "https://pypi.tuna.tsinghua.edu.cn/simple",
|
||||||
"--trusted-host", "pypi.tuna.tsinghua.edu.cn"])
|
"--trusted-host", "pypi.tuna.tsinghua.edu.cn"])
|
||||||
|
|
||||||
|
|||||||
@@ -93,10 +93,10 @@ class ChatCompletionRequest(RequestBase):
|
|||||||
if 'function_call' in choice0['message']:
|
if 'function_call' in choice0['message']:
|
||||||
self.pending_func_call = choice0['message']['function_call']
|
self.pending_func_call = choice0['message']['function_call']
|
||||||
|
|
||||||
self.append_message(
|
# self.append_message(
|
||||||
role="assistant",
|
# role="assistant",
|
||||||
content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False)
|
# content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False)
|
||||||
)
|
# )
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": resp["id"],
|
"id": resp["id"],
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ Completion - text-davinci-003 等模型
|
|||||||
"""
|
"""
|
||||||
import openai, logging, threading, asyncio
|
import openai, logging, threading, asyncio
|
||||||
import openai.error as aiE
|
import openai.error as aiE
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
from pkg.openai.api.model import RequestBase
|
from pkg.openai.api.model import RequestBase
|
||||||
from pkg.openai.api.completion import CompletionRequest
|
from pkg.openai.api.completion import CompletionRequest
|
||||||
@@ -49,3 +50,70 @@ def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBa
|
|||||||
elif model_name in COMPLETION_MODELS:
|
elif model_name in COMPLETION_MODELS:
|
||||||
return CompletionRequest(model_name, messages, **args)
|
return CompletionRequest(model_name, messages, **args)
|
||||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
|
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
|
||||||
|
|
||||||
|
|
||||||
|
def count_chat_completion_tokens(messages: list, model: str) -> int:
|
||||||
|
"""Return the number of tokens used by a list of messages."""
|
||||||
|
try:
|
||||||
|
encoding = tiktoken.encoding_for_model(model)
|
||||||
|
except KeyError:
|
||||||
|
print("Warning: model not found. Using cl100k_base encoding.")
|
||||||
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
if model in {
|
||||||
|
"gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k-0613",
|
||||||
|
"gpt-4-0314",
|
||||||
|
"gpt-4-32k-0314",
|
||||||
|
"gpt-4-0613",
|
||||||
|
"gpt-4-32k-0613",
|
||||||
|
}:
|
||||||
|
tokens_per_message = 3
|
||||||
|
tokens_per_name = 1
|
||||||
|
elif model == "gpt-3.5-turbo-0301":
|
||||||
|
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
|
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||||
|
elif "gpt-3.5-turbo" in model:
|
||||||
|
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
||||||
|
return count_chat_completion_tokens(messages, model="gpt-3.5-turbo-0613")
|
||||||
|
elif "gpt-4" in model:
|
||||||
|
# print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||||
|
return count_chat_completion_tokens(messages, model="gpt-4-0613")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"""count_chat_completion_tokens() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
|
||||||
|
)
|
||||||
|
num_tokens = 0
|
||||||
|
for message in messages:
|
||||||
|
num_tokens += tokens_per_message
|
||||||
|
for key, value in message.items():
|
||||||
|
num_tokens += len(encoding.encode(value))
|
||||||
|
if key == "name":
|
||||||
|
num_tokens += tokens_per_name
|
||||||
|
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||||
|
return num_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def count_completion_tokens(messages: list, model: str) -> int:
|
||||||
|
|
||||||
|
try:
|
||||||
|
encoding = tiktoken.encoding_for_model(model)
|
||||||
|
except KeyError:
|
||||||
|
print("Warning: model not found. Using cl100k_base encoding.")
|
||||||
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
text += message['role'] + message['content'] + "\n"
|
||||||
|
|
||||||
|
text += "assistant: "
|
||||||
|
|
||||||
|
return len(encoding.encode(text))
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(messages: list, model: str):
|
||||||
|
if model in CHAT_COMPLETION_MODELS:
|
||||||
|
return count_chat_completion_tokens(messages, model)
|
||||||
|
elif model in COMPLETION_MODELS:
|
||||||
|
return count_completion_tokens(messages, model)
|
||||||
|
raise ValueError("不支持模型[{}],请检查配置文件".format(model))
|
||||||
|
|||||||
+15
-29
@@ -16,6 +16,8 @@ import pkg.utils.context
|
|||||||
import pkg.plugin.host as plugin_host
|
import pkg.plugin.host as plugin_host
|
||||||
import pkg.plugin.models as plugin_models
|
import pkg.plugin.models as plugin_models
|
||||||
|
|
||||||
|
from pkg.openai.modelmgr import count_tokens
|
||||||
|
|
||||||
# 运行时保存的所有session
|
# 运行时保存的所有session
|
||||||
sessions = {}
|
sessions = {}
|
||||||
|
|
||||||
@@ -107,9 +109,6 @@ class Session:
|
|||||||
prompt = []
|
prompt = []
|
||||||
"""使用list来保存会话中的回合"""
|
"""使用list来保存会话中的回合"""
|
||||||
|
|
||||||
token_counts = []
|
|
||||||
"""每个回合的token数量"""
|
|
||||||
|
|
||||||
default_prompt = []
|
default_prompt = []
|
||||||
"""本session的默认prompt"""
|
"""本session的默认prompt"""
|
||||||
|
|
||||||
@@ -215,12 +214,7 @@ class Session:
|
|||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
max_length = config.prompt_submit_length
|
max_length = config.prompt_submit_length
|
||||||
|
|
||||||
prompts, counts = self.cut_out(text, max_length)
|
prompts, _ = self.cut_out(text, max_length)
|
||||||
|
|
||||||
# 计算请求前的prompt数量
|
|
||||||
total_token_before_query = 0
|
|
||||||
for token_count in counts:
|
|
||||||
total_token_before_query += token_count
|
|
||||||
|
|
||||||
res_text = ""
|
res_text = ""
|
||||||
|
|
||||||
@@ -281,8 +275,8 @@ class Session:
|
|||||||
self.prompt += pending_msgs
|
self.prompt += pending_msgs
|
||||||
|
|
||||||
# 向token_counts中添加本回合的token数量
|
# 向token_counts中添加本回合的token数量
|
||||||
self.token_counts.append(total_tokens-total_token_before_query)
|
# self.token_counts.append(total_tokens-total_token_before_query)
|
||||||
logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
|
# logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
|
||||||
|
|
||||||
if self.just_switched_to_exist_session:
|
if self.just_switched_to_exist_session:
|
||||||
self.just_switched_to_exist_session = False
|
self.just_switched_to_exist_session = False
|
||||||
@@ -319,24 +313,19 @@ class Session:
|
|||||||
|
|
||||||
# 包装目前的对话回合内容
|
# 包装目前的对话回合内容
|
||||||
changable_prompts = []
|
changable_prompts = []
|
||||||
changable_counts = []
|
|
||||||
# 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1
|
|
||||||
changable_index = len(self.prompt) - 1
|
|
||||||
token_count_index = len(self.token_counts) - 1
|
|
||||||
|
|
||||||
packed_tokens = 0
|
use_model = pkg.utils.context.get_config().completion_api_params['model']
|
||||||
|
|
||||||
while changable_index >= 0 and token_count_index >= 0:
|
ptr = len(self.prompt) - 1
|
||||||
if packed_tokens + self.token_counts[token_count_index] > max_tokens:
|
|
||||||
|
# 直接从后向前扫描拼接,不管是否是整回合
|
||||||
|
while ptr >= 0:
|
||||||
|
if count_tokens(self.prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
|
||||||
break
|
break
|
||||||
|
|
||||||
changable_prompts.insert(0, self.prompt[changable_index])
|
changable_prompts.insert(0, self.prompt[ptr])
|
||||||
changable_prompts.insert(0, self.prompt[changable_index - 1])
|
|
||||||
changable_counts.insert(0, self.token_counts[token_count_index])
|
|
||||||
packed_tokens += self.token_counts[token_count_index]
|
|
||||||
|
|
||||||
changable_index -= 2
|
ptr -= 1
|
||||||
token_count_index -= 1
|
|
||||||
|
|
||||||
# 将default_prompt和changable_prompts合并
|
# 将default_prompt和changable_prompts合并
|
||||||
result_prompt = self.default_prompt + changable_prompts
|
result_prompt = self.default_prompt + changable_prompts
|
||||||
@@ -349,12 +338,9 @@ class Session:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4),
|
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
|
||||||
packed_tokens,
|
|
||||||
changable_counts,
|
|
||||||
self.token_counts))
|
|
||||||
|
|
||||||
return result_prompt, changable_counts
|
return result_prompt, count_tokens(changable_prompts, use_model)
|
||||||
|
|
||||||
# 持久化session
|
# 持久化session
|
||||||
def persistence(self):
|
def persistence(self):
|
||||||
|
|||||||
+2
-1
@@ -2,10 +2,11 @@ requests~=2.31.0
|
|||||||
openai~=0.27.8
|
openai~=0.27.8
|
||||||
dulwich~=0.21.5
|
dulwich~=0.21.5
|
||||||
colorlog~=6.6.0
|
colorlog~=6.6.0
|
||||||
yiri-mirai~=0.2.7
|
yiri-mirai
|
||||||
websockets
|
websockets
|
||||||
urllib3~=1.26.10
|
urllib3~=1.26.10
|
||||||
func_timeout~=4.3.5
|
func_timeout~=4.3.5
|
||||||
Pillow
|
Pillow
|
||||||
nakuru-project-idk
|
nakuru-project-idk
|
||||||
CallingGPT
|
CallingGPT
|
||||||
|
tiktoken
|
||||||
Reference in New Issue
Block a user