feat: batch update with laisky's one-api

This commit is contained in:
Laisky.Cai
2025-03-20 10:43:14 +00:00
parent 761ee32d19
commit b2d6aa783b
35 changed files with 479 additions and 108 deletions

View File

@@ -1,16 +1,20 @@
package openai
import (
"errors"
"bytes"
"context"
"encoding/base64"
"fmt"
"math"
"strings"
"github.com/pkg/errors"
"github.com/pkoukk/tiktoken-go"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/billing/ratio"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -73,8 +77,10 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
func CountTokenMessages(messages []model.Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// CountTokenMessages counts the number of tokens in a list of messages.
func CountTokenMessages(ctx context.Context,
messages []model.Message, actualModel string) int {
tokenEncoder := getTokenEncoder(actualModel)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
@@ -82,47 +88,54 @@ func CountTokenMessages(messages []model.Message, model string) int {
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
if actualModel == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
var totalAudioTokens float64
for _, message := range messages {
tokenNum += tokensPerMessage
switch v := message.Content.(type) {
case string:
tokenNum += getTokenNum(tokenEncoder, v)
case []any:
for _, it := range v {
m := it.(map[string]any)
switch m["type"] {
case "text":
if textValue, ok := m["text"]; ok {
if textString, ok := textValue.(string); ok {
tokenNum += getTokenNum(tokenEncoder, textString)
}
}
case "image_url":
imageUrl, ok := m["image_url"].(map[string]any)
if ok {
url := imageUrl["url"].(string)
detail := ""
if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string)
}
imageTokens, err := countImageTokens(url, detail, model)
if err != nil {
logger.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
}
contents := message.ParseContent()
for _, content := range contents {
switch content.Type {
case model.ContentTypeText:
if content.Text != nil {
tokenNum += getTokenNum(tokenEncoder, *content.Text)
}
case model.ContentTypeImageURL:
imageTokens, err := countImageTokens(
content.ImageURL.Url,
content.ImageURL.Detail,
actualModel)
if err != nil {
logger.SysError("error counting image tokens: " + err.Error())
} else {
tokenNum += imageTokens
}
case model.ContentTypeInputAudio:
audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data)
if err != nil {
logger.SysError("error decoding audio data: " + err.Error())
}
audioTokens, err := helper.GetAudioTokens(ctx,
bytes.NewReader(audioData),
ratio.GetAudioPromptTokensPerSecond(actualModel))
if err != nil {
logger.SysError("error counting audio tokens: " + err.Error())
} else {
totalAudioTokens += audioTokens
}
}
}
tokenNum += int(math.Ceil(totalAudioTokens))
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName