mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-18 09:36:37 +08:00
refactor: Refactor code for improved efficiency and readability
- Refactored code in relay-utils.go - Eliminated unused imports and redundant function - Improved code readability with added comments - Cleaned up by removing unnecessary commented-out code
This commit is contained in:
parent
7239b3386a
commit
c503a87c74
@ -1,13 +1,9 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"image/jpeg"
|
|
||||||
"image/png"
|
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -78,49 +74,49 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|||||||
// CountVisionImageToken count vision image tokens
|
// CountVisionImageToken count vision image tokens
|
||||||
//
|
//
|
||||||
// https://openai.com/pricing
|
// https://openai.com/pricing
|
||||||
func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) {
|
// func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) {
|
||||||
width, height, err := imageSize(cnt)
|
// width, height, err := imageSize(cnt)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return 0, errors.Wrap(err, "get image size")
|
// return 0, errors.Wrap(err, "get image size")
|
||||||
}
|
// }
|
||||||
|
|
||||||
switch resolution {
|
// switch resolution {
|
||||||
case VisionImageResolutionLow:
|
// case VisionImageResolutionLow:
|
||||||
return 85, nil // fixed price
|
// return 85, nil // fixed price
|
||||||
case VisionImageResolutionHigh:
|
// case VisionImageResolutionHigh:
|
||||||
h := math.Ceil(float64(height) / 512)
|
// h := math.Ceil(float64(height) / 512)
|
||||||
w := math.Ceil(float64(width) / 512)
|
// w := math.Ceil(float64(width) / 512)
|
||||||
n := w * h
|
// n := w * h
|
||||||
total := 85 + 170*n
|
// total := 85 + 170*n
|
||||||
return int(total), nil
|
// return int(total), nil
|
||||||
default:
|
// default:
|
||||||
return 0, errors.Errorf("unsupport resolution %q", resolution)
|
// return 0, errors.Errorf("unsupport resolution %q", resolution)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func imageSize(cnt []byte) (width, height int, err error) {
|
// func imageSize(cnt []byte) (width, height int, err error) {
|
||||||
contentType := http.DetectContentType(cnt)
|
// contentType := http.DetectContentType(cnt)
|
||||||
switch contentType {
|
// switch contentType {
|
||||||
case "image/jpeg", "image/jpg":
|
// case "image/jpeg", "image/jpg":
|
||||||
img, err := jpeg.Decode(bytes.NewReader(cnt))
|
// img, err := jpeg.Decode(bytes.NewReader(cnt))
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return 0, 0, errors.Wrap(err, "decode jpeg")
|
// return 0, 0, errors.Wrap(err, "decode jpeg")
|
||||||
}
|
// }
|
||||||
|
|
||||||
bounds := img.Bounds()
|
// bounds := img.Bounds()
|
||||||
return bounds.Dx(), bounds.Dy(), nil
|
// return bounds.Dx(), bounds.Dy(), nil
|
||||||
case "image/png":
|
// case "image/png":
|
||||||
img, err := png.Decode(bytes.NewReader(cnt))
|
// img, err := png.Decode(bytes.NewReader(cnt))
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return 0, 0, errors.Wrap(err, "decode png")
|
// return 0, 0, errors.Wrap(err, "decode png")
|
||||||
}
|
// }
|
||||||
|
|
||||||
bounds := img.Bounds()
|
// bounds := img.Bounds()
|
||||||
return bounds.Dx(), bounds.Dy(), nil
|
// return bounds.Dx(), bounds.Dy(), nil
|
||||||
default:
|
// default:
|
||||||
return 0, 0, errors.Errorf("unsupport image content type %q", contentType)
|
// return 0, 0, errors.Errorf("unsupport image content type %q", contentType)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func countTokenMessages(messages []Message, model string) int {
|
func countTokenMessages(messages []Message, model string) int {
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
@ -178,52 +174,52 @@ func countTokenMessages(messages []Message, model string) int {
|
|||||||
return tokenNum
|
return tokenNum
|
||||||
}
|
}
|
||||||
|
|
||||||
func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
|
// func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
|
||||||
tokenEncoder := getTokenEncoder(model)
|
// tokenEncoder := getTokenEncoder(model)
|
||||||
// Reference:
|
// // Reference:
|
||||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
||||||
//
|
// //
|
||||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
var tokensPerMessage int
|
// var tokensPerMessage int
|
||||||
var tokensPerName int
|
// var tokensPerName int
|
||||||
if model == "gpt-3.5-turbo-0301" {
|
// if model == "gpt-3.5-turbo-0301" {
|
||||||
tokensPerMessage = 4
|
// tokensPerMessage = 4
|
||||||
tokensPerName = -1 // If there's a name, the role is omitted
|
// tokensPerName = -1 // If there's a name, the role is omitted
|
||||||
} else {
|
// } else {
|
||||||
tokensPerMessage = 3
|
// tokensPerMessage = 3
|
||||||
tokensPerName = 1
|
// tokensPerName = 1
|
||||||
}
|
// }
|
||||||
tokenNum := 0
|
// tokenNum := 0
|
||||||
for _, message := range messages {
|
// for _, message := range messages {
|
||||||
tokenNum += tokensPerMessage
|
// tokenNum += tokensPerMessage
|
||||||
for _, cnt := range message.Content {
|
// for _, cnt := range message.Content {
|
||||||
switch cnt.Type {
|
// switch cnt.Type {
|
||||||
case OpenaiVisionMessageContentTypeText:
|
// case OpenaiVisionMessageContentTypeText:
|
||||||
tokenNum += getTokenNum(tokenEncoder, cnt.Text)
|
// tokenNum += getTokenNum(tokenEncoder, cnt.Text)
|
||||||
case OpenaiVisionMessageContentTypeImageUrl:
|
// case OpenaiVisionMessageContentTypeImageUrl:
|
||||||
imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,"))
|
// imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,"))
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to decode base64 image")
|
// return 0, errors.Wrap(err, "failed to decode base64 image")
|
||||||
}
|
// }
|
||||||
|
|
||||||
if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil {
|
// if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to count vision image token")
|
// return 0, errors.Wrap(err, "failed to count vision image token")
|
||||||
} else {
|
// } else {
|
||||||
tokenNum += imgtoken
|
// tokenNum += imgtoken
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
// tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||||
if message.Name != nil {
|
// if message.Name != nil {
|
||||||
tokenNum += tokensPerName
|
// tokenNum += tokensPerName
|
||||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||||
return tokenNum, nil
|
// return tokenNum, nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
const (
|
const (
|
||||||
lowDetailCost = 85
|
lowDetailCost = 85
|
||||||
|
Loading…
Reference in New Issue
Block a user