feat: Add token counting functionality to vision-related functions

- Add function `CountVisionImageToken` to count vision image tokens
- Modify function `imageSize` to handle different image types
- Add function `countVisonTokenMessages` to count tokens in vision messages
- Add logic to count tokens for different types of vision messages in `countVisonTokenMessages`
- Add tokens for role and name in `countVisonTokenMessages`
- Update total token count calculation in `countVisonTokenMessages` to include image tokens and message tokens
- Add constant values for tokens per message and tokens per name in `countVisonTokenMessages`
- Modify the error message on line 12 to include the JSON string that failed to unmarshal
This commit is contained in:
Laisky.Cai 2023-11-17 04:00:49 +00:00
parent d0c0b9b650
commit 08ca72184a
3 changed files with 115 additions and 8 deletions

View File

@ -202,12 +202,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
messages, err := textRequest.TextMessages()
if err != nil {
return errorWrapper(err, "parse_text_messages_failed", http.StatusBadRequest)
// first try to parse as text messages
if messages, err := textRequest.TextMessages(); err != nil {
// then try to parse as vision messages
if messages, err := textRequest.VisionMessages(); err != nil {
return errorWrapper(err, "parse_text_messages_failed", http.StatusBadRequest)
} else {
// vision message
if promptTokens, err = countVisonTokenMessages(messages, textRequest.Model); err != nil {
return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
}
} else {
promptTokens = countTokenMessages(messages, textRequest.Model)
}
promptTokens = countTokenMessages(messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:

View File

@ -1,15 +1,22 @@
package controller
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"image/jpeg"
"image/png"
"io"
"math"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
)
var stopFinishReason = "stop"
@ -65,6 +72,53 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
// CountVisionImageToken count vision image tokens
//
// https://openai.com/pricing
func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) {
width, height, err := imageSize(cnt)
if err != nil {
return 0, errors.Wrap(err, "get image size")
}
switch resolution {
case VisionImageResolutionLow:
return 85, nil // fixed price
case VisionImageResolutionHigh:
h := math.Ceil(float64(height) / 512)
w := math.Ceil(float64(width) / 512)
n := w * h
total := 85 + 170*n
return int(total), nil
default:
return 0, errors.Errorf("unsupport resolution %q", resolution)
}
}
func imageSize(cnt []byte) (width, height int, err error) {
contentType := http.DetectContentType(cnt)
switch contentType {
case "image/jpeg", "image/jpg":
img, err := jpeg.Decode(bytes.NewReader(cnt))
if err != nil {
return 0, 0, errors.Wrap(err, "decode jpeg")
}
bounds := img.Bounds()
return bounds.Dx(), bounds.Dy(), nil
case "image/png":
img, err := png.Decode(bytes.NewReader(cnt))
if err != nil {
return 0, 0, errors.Wrap(err, "decode png")
}
bounds := img.Bounds()
return bounds.Dx(), bounds.Dy(), nil
default:
return 0, 0, errors.Errorf("unsupport image content type %q", contentType)
}
}
func countTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
@ -95,6 +149,51 @@ func countTokenMessages(messages []Message, model string) int {
return tokenNum
}
func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
switch message.Content.Type {
case OpenaiVisionMessageContentTypeText:
tokenNum += getTokenNum(tokenEncoder, message.Content.Text)
case OpenaiVisionMessageContentTypeImageUrl:
imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(message.Content.ImageUrl.URL, "data:image/jpeg;base64,"))
if err != nil {
return 0, errors.Wrap(err, "failed to decode base64 image")
}
if imgtoken, err := CountVisionImageToken(imgblob, message.Content.ImageUrl.Detail); err != nil {
return 0, errors.Wrap(err, "failed to count vision image token")
} else {
tokenNum += imgtoken
}
}
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum, nil
}
func countTokenInput(input any, model string) int {
switch input.(type) {
case string:

View File

@ -105,7 +105,7 @@ func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
if blob, err := json.Marshal(r.Messages); err != nil {
return nil, errors.Wrap(err, "marshal messages failed")
} else if err := json.Unmarshal(blob, &messages); err != nil {
return nil, errors.Wrap(err, "unmarshal messages failed")
return nil, errors.Wrapf(err, "unmarshal messages failed %q", string(blob))
} else {
return messages, nil
}