diff --git a/controller/relay-text.go b/controller/relay-text.go
index 50b45140..68552b53 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -202,12 +202,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	var completionTokens int
 	switch relayMode {
 	case RelayModeChatCompletions:
-		messages, err := textRequest.TextMessages()
-		if err != nil {
-			return errorWrapper(err, "parse_text_messages_failed", http.StatusBadRequest)
+		// first try to parse as text messages
+		if messages, err := textRequest.TextMessages(); err != nil {
+			// then try to parse as vision messages
+			if messages, err := textRequest.VisionMessages(); err != nil {
+				return errorWrapper(err, "parse_text_messages_failed", http.StatusBadRequest)
+			} else {
+				// vision message
+				if promptTokens, err = countVisonTokenMessages(messages, textRequest.Model); err != nil {
+					return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
+				}
+			}
+		} else {
+			promptTokens = countTokenMessages(messages, textRequest.Model)
 		}
-
-		promptTokens = countTokenMessages(messages, textRequest.Model)
 	case RelayModeCompletions:
 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
 	case RelayModeModerations:
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index 407d876b..0538aad7 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -1,15 +1,22 @@
 package controller
 
 import (
+	"bytes"
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"github.com/gin-gonic/gin"
-	"github.com/pkoukk/tiktoken-go"
+	"image/jpeg"
+	"image/png"
 	"io"
+	"math"
 	"net/http"
 	"one-api/common"
 	"strconv"
 	"strings"
+
+	"github.com/Laisky/errors/v2"
+	"github.com/gin-gonic/gin"
+	"github.com/pkoukk/tiktoken-go"
 )
 
 var stopFinishReason = "stop"
@@ -65,6 +72,69 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
+// CountVisionImageToken returns the number of prompt tokens billed for one
+// image sent to a vision model.
+//
+// Low detail is a fixed 85 tokens. For high detail the image is first scaled
+// down to fit a 2048x2048 square, then scaled down so that its shortest side
+// is at most 768px; every 512px tile of the result costs 170 tokens on top of
+// the 85 base tokens.
+//
+// https://openai.com/pricing
+func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) {
+	switch resolution {
+	case VisionImageResolutionLow:
+		// fixed price: the image does not have to be decoded at all
+		return 85, nil
+	case VisionImageResolutionHigh:
+		width, height, err := imageSize(cnt)
+		if err != nil {
+			return 0, errors.Wrap(err, "get image size")
+		}
+
+		w, h := float64(width), float64(height)
+		// fit within 2048x2048, keeping the aspect ratio
+		if longest := math.Max(w, h); longest > 2048 {
+			w, h = w*2048/longest, h*2048/longest
+		}
+		// shrink until the shortest side is at most 768px
+		if shortest := math.Min(w, h); shortest > 768 {
+			w, h = w*768/shortest, h*768/shortest
+		}
+
+		tiles := math.Ceil(w/512) * math.Ceil(h/512)
+		return 85 + 170*int(tiles), nil
+	default:
+		return 0, errors.Errorf("unsupport resolution %q", resolution)
+	}
+}
+
+// imageSize reports the pixel width and height of a JPEG or PNG image.
+//
+// Only the image header is parsed (DecodeConfig), the pixel data is never
+// decoded.
+func imageSize(cnt []byte) (width, height int, err error) {
+	contentType := http.DetectContentType(cnt)
+	switch contentType {
+	case "image/jpeg", "image/jpg":
+		cfg, err := jpeg.DecodeConfig(bytes.NewReader(cnt))
+		if err != nil {
+			return 0, 0, errors.Wrap(err, "decode jpeg")
+		}
+
+		return cfg.Width, cfg.Height, nil
+	case "image/png":
+		cfg, err := png.DecodeConfig(bytes.NewReader(cnt))
+		if err != nil {
+			return 0, 0, errors.Wrap(err, "decode png")
+		}
+
+		return cfg.Width, cfg.Height, nil
+	default:
+		return 0, 0, errors.Errorf("unsupport image content type %q", contentType)
+	}
+}
+
 func countTokenMessages(messages []Message, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -95,6 +165,66 @@ func countTokenMessages(messages []Message, model string) int {
 	return tokenNum
 }
 
+// countVisonTokenMessages counts the prompt tokens of vision chat messages:
+// text parts go through the tokenizer, image parts are priced by
+// CountVisionImageToken.
+//
+// Images must be inlined as base64 (a data URL of any media type or a bare
+// base64 payload); anything else is reported as an error.
+func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
+	tokenEncoder := getTokenEncoder(model)
+	// Reference:
+	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	// https://github.com/pkoukk/tiktoken-go/issues/6
+	//
+	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
+	var tokensPerMessage int
+	var tokensPerName int
+	if model == "gpt-3.5-turbo-0301" {
+		tokensPerMessage = 4
+		tokensPerName = -1 // If there's a name, the role is omitted
+	} else {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	}
+	tokenNum := 0
+	for _, message := range messages {
+		tokenNum += tokensPerMessage
+		switch message.Content.Type {
+		case OpenaiVisionMessageContentTypeText:
+			tokenNum += getTokenNum(tokenEncoder, message.Content.Text)
+		case OpenaiVisionMessageContentTypeImageUrl:
+			// strip the "data:<media type>;base64," header of any data URL,
+			// not only the JPEG one
+			payload := message.Content.ImageUrl.URL
+			if strings.HasPrefix(payload, "data:") {
+				if _, encoded, ok := strings.Cut(payload, ";base64,"); ok {
+					payload = encoded
+				}
+			}
+
+			imgblob, err := base64.StdEncoding.DecodeString(payload)
+			if err != nil {
+				return 0, errors.Wrap(err, "failed to decode base64 image")
+			}
+
+			imgtoken, err := CountVisionImageToken(imgblob, message.Content.ImageUrl.Detail)
+			if err != nil {
+				return 0, errors.Wrap(err, "failed to count vision image token")
+			}
+			tokenNum += imgtoken
+		}
+
+		tokenNum += getTokenNum(tokenEncoder, message.Role)
+		if message.Name != nil {
+			tokenNum += tokensPerName
+			tokenNum += getTokenNum(tokenEncoder, *message.Name)
+		}
+	}
+	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+	return tokenNum, nil
+}
+
 func countTokenInput(input any, model string) int {
 	switch input.(type) {
 	case string:
diff --git a/controller/relay.go b/controller/relay.go
index 5d41d10b..0ca66190 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -105,7 +105,7 @@ func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
 	if blob, err := json.Marshal(r.Messages); err != nil {
 		return nil, errors.Wrap(err, "marshal messages failed")
 	} else if err := json.Unmarshal(blob, &messages); err != nil {
-		return nil, errors.Wrap(err, "unmarshal messages failed")
+		return nil, errors.Wrapf(err, "unmarshal messages failed %q", string(blob))
 	} else {
 		return messages, nil
 	}