mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-22 03:16:38 +08:00
feat: Add token counting functionality to vision-related functions
- Add function `CountVisionImageToken` to count vision image tokens - Modify function `imageSize` to handle different image types - Add function `countVisonTokenMessages` to count tokens in vision messages - Add logic to count tokens for different types of vision messages in `countVisonTokenMessages` - Add tokens for role and name in `countVisonTokenMessages` - Update total token count calculation in `countVisonTokenMessages` to include image tokens and message tokens - Add constant values for tokens per message and tokens per name in `countVisonTokenMessages` - Modify the error message on line 12 to include the JSON string that failed to unmarshal
This commit is contained in:
parent
d0c0b9b650
commit
08ca72184a
@ -202,12 +202,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
var completionTokens int
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
messages, err := textRequest.TextMessages()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "parse_text_messages_failed", http.StatusBadRequest)
|
||||
// first try to parse as text messages
|
||||
if messages, err := textRequest.TextMessages(); err != nil {
|
||||
// then try to parse as vision messages
|
||||
if messages, err := textRequest.VisionMessages(); err != nil {
|
||||
return errorWrapper(err, "parse_text_messages_failed", http.StatusBadRequest)
|
||||
} else {
|
||||
// vision message
|
||||
if promptTokens, err = countVisonTokenMessages(messages, textRequest.Model); err != nil {
|
||||
return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
promptTokens = countTokenMessages(messages, textRequest.Model)
|
||||
}
|
||||
|
||||
promptTokens = countTokenMessages(messages, textRequest.Model)
|
||||
case RelayModeCompletions:
|
||||
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case RelayModeModerations:
|
||||
|
@ -1,15 +1,22 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Laisky/errors/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
var stopFinishReason = "stop"
|
||||
@ -65,6 +72,53 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
return len(tokenEncoder.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
// CountVisionImageToken count vision image tokens
|
||||
//
|
||||
// https://openai.com/pricing
|
||||
func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) {
|
||||
width, height, err := imageSize(cnt)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "get image size")
|
||||
}
|
||||
|
||||
switch resolution {
|
||||
case VisionImageResolutionLow:
|
||||
return 85, nil // fixed price
|
||||
case VisionImageResolutionHigh:
|
||||
h := math.Ceil(float64(height) / 512)
|
||||
w := math.Ceil(float64(width) / 512)
|
||||
n := w * h
|
||||
total := 85 + 170*n
|
||||
return int(total), nil
|
||||
default:
|
||||
return 0, errors.Errorf("unsupport resolution %q", resolution)
|
||||
}
|
||||
}
|
||||
|
||||
func imageSize(cnt []byte) (width, height int, err error) {
|
||||
contentType := http.DetectContentType(cnt)
|
||||
switch contentType {
|
||||
case "image/jpeg", "image/jpg":
|
||||
img, err := jpeg.Decode(bytes.NewReader(cnt))
|
||||
if err != nil {
|
||||
return 0, 0, errors.Wrap(err, "decode jpeg")
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
return bounds.Dx(), bounds.Dy(), nil
|
||||
case "image/png":
|
||||
img, err := png.Decode(bytes.NewReader(cnt))
|
||||
if err != nil {
|
||||
return 0, 0, errors.Wrap(err, "decode png")
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
return bounds.Dx(), bounds.Dy(), nil
|
||||
default:
|
||||
return 0, 0, errors.Errorf("unsupport image content type %q", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func countTokenMessages(messages []Message, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
@ -95,6 +149,51 @@ func countTokenMessages(messages []Message, model string) int {
|
||||
return tokenNum
|
||||
}
|
||||
|
||||
func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
//
|
||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
var tokensPerMessage int
|
||||
var tokensPerName int
|
||||
if model == "gpt-3.5-turbo-0301" {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1 // If there's a name, the role is omitted
|
||||
} else {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
}
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
switch message.Content.Type {
|
||||
case OpenaiVisionMessageContentTypeText:
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Content.Text)
|
||||
case OpenaiVisionMessageContentTypeImageUrl:
|
||||
imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(message.Content.ImageUrl.URL, "data:image/jpeg;base64,"))
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to decode base64 image")
|
||||
}
|
||||
|
||||
if imgtoken, err := CountVisionImageToken(imgblob, message.Content.ImageUrl.Detail); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to count vision image token")
|
||||
} else {
|
||||
tokenNum += imgtoken
|
||||
}
|
||||
}
|
||||
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
}
|
||||
}
|
||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func countTokenInput(input any, model string) int {
|
||||
switch input.(type) {
|
||||
case string:
|
||||
|
@ -105,7 +105,7 @@ func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
|
||||
if blob, err := json.Marshal(r.Messages); err != nil {
|
||||
return nil, errors.Wrap(err, "marshal messages failed")
|
||||
} else if err := json.Unmarshal(blob, &messages); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal messages failed")
|
||||
return nil, errors.Wrapf(err, "unmarshal messages failed %q", string(blob))
|
||||
} else {
|
||||
return messages, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user