mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-12 03:13:41 +08:00
feat: batch update with laisky's one-api
This commit is contained in:
@@ -1,16 +1,20 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/image"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
@@ -73,8 +77,10 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
return len(tokenEncoder.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []model.Message, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// CountTokenMessages counts the number of tokens in a list of messages.
|
||||
func CountTokenMessages(ctx context.Context,
|
||||
messages []model.Message, actualModel string) int {
|
||||
tokenEncoder := getTokenEncoder(actualModel)
|
||||
// Reference:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
@@ -82,47 +88,54 @@ func CountTokenMessages(messages []model.Message, model string) int {
|
||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
var tokensPerMessage int
|
||||
var tokensPerName int
|
||||
if model == "gpt-3.5-turbo-0301" {
|
||||
if actualModel == "gpt-3.5-turbo-0301" {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1 // If there's a name, the role is omitted
|
||||
} else {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
}
|
||||
|
||||
tokenNum := 0
|
||||
var totalAudioTokens float64
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
switch v := message.Content.(type) {
|
||||
case string:
|
||||
tokenNum += getTokenNum(tokenEncoder, v)
|
||||
case []any:
|
||||
for _, it := range v {
|
||||
m := it.(map[string]any)
|
||||
switch m["type"] {
|
||||
case "text":
|
||||
if textValue, ok := m["text"]; ok {
|
||||
if textString, ok := textValue.(string); ok {
|
||||
tokenNum += getTokenNum(tokenEncoder, textString)
|
||||
}
|
||||
}
|
||||
case "image_url":
|
||||
imageUrl, ok := m["image_url"].(map[string]any)
|
||||
if ok {
|
||||
url := imageUrl["url"].(string)
|
||||
detail := ""
|
||||
if imageUrl["detail"] != nil {
|
||||
detail = imageUrl["detail"].(string)
|
||||
}
|
||||
imageTokens, err := countImageTokens(url, detail, model)
|
||||
if err != nil {
|
||||
logger.SysError("error counting image tokens: " + err.Error())
|
||||
} else {
|
||||
tokenNum += imageTokens
|
||||
}
|
||||
}
|
||||
contents := message.ParseContent()
|
||||
for _, content := range contents {
|
||||
switch content.Type {
|
||||
case model.ContentTypeText:
|
||||
if content.Text != nil {
|
||||
tokenNum += getTokenNum(tokenEncoder, *content.Text)
|
||||
}
|
||||
case model.ContentTypeImageURL:
|
||||
imageTokens, err := countImageTokens(
|
||||
content.ImageURL.Url,
|
||||
content.ImageURL.Detail,
|
||||
actualModel)
|
||||
if err != nil {
|
||||
logger.SysError("error counting image tokens: " + err.Error())
|
||||
} else {
|
||||
tokenNum += imageTokens
|
||||
}
|
||||
case model.ContentTypeInputAudio:
|
||||
audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data)
|
||||
if err != nil {
|
||||
logger.SysError("error decoding audio data: " + err.Error())
|
||||
}
|
||||
|
||||
audioTokens, err := helper.GetAudioTokens(ctx,
|
||||
bytes.NewReader(audioData),
|
||||
ratio.GetAudioPromptTokensPerSecond(actualModel))
|
||||
if err != nil {
|
||||
logger.SysError("error counting audio tokens: " + err.Error())
|
||||
} else {
|
||||
totalAudioTokens += audioTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenNum += int(math.Ceil(totalAudioTokens))
|
||||
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
|
||||
Reference in New Issue
Block a user