refactor: Refactor code for improved efficiency and readability

- Refactored code in relay-utils.go
- Eliminated unused imports and redundant function
- Improved code readability with added comments
- Cleaned up by removing unnecessary commented-out code
This commit is contained in:
Laisky.Cai 2023-12-12 06:09:11 +00:00
parent 7239b3386a
commit c503a87c74

View File

@ -1,13 +1,9 @@
package controller package controller
import ( import (
"bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"image/jpeg"
"image/png"
"io" "io"
"math" "math"
"net/http" "net/http"
@ -78,49 +74,49 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
// CountVisionImageToken count vision image tokens // CountVisionImageToken count vision image tokens
// //
// https://openai.com/pricing // https://openai.com/pricing
func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) { // func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) {
width, height, err := imageSize(cnt) // width, height, err := imageSize(cnt)
if err != nil { // if err != nil {
return 0, errors.Wrap(err, "get image size") // return 0, errors.Wrap(err, "get image size")
} // }
switch resolution { // switch resolution {
case VisionImageResolutionLow: // case VisionImageResolutionLow:
return 85, nil // fixed price // return 85, nil // fixed price
case VisionImageResolutionHigh: // case VisionImageResolutionHigh:
h := math.Ceil(float64(height) / 512) // h := math.Ceil(float64(height) / 512)
w := math.Ceil(float64(width) / 512) // w := math.Ceil(float64(width) / 512)
n := w * h // n := w * h
total := 85 + 170*n // total := 85 + 170*n
return int(total), nil // return int(total), nil
default: // default:
return 0, errors.Errorf("unsupport resolution %q", resolution) // return 0, errors.Errorf("unsupport resolution %q", resolution)
} // }
} // }
func imageSize(cnt []byte) (width, height int, err error) { // func imageSize(cnt []byte) (width, height int, err error) {
contentType := http.DetectContentType(cnt) // contentType := http.DetectContentType(cnt)
switch contentType { // switch contentType {
case "image/jpeg", "image/jpg": // case "image/jpeg", "image/jpg":
img, err := jpeg.Decode(bytes.NewReader(cnt)) // img, err := jpeg.Decode(bytes.NewReader(cnt))
if err != nil { // if err != nil {
return 0, 0, errors.Wrap(err, "decode jpeg") // return 0, 0, errors.Wrap(err, "decode jpeg")
} // }
bounds := img.Bounds() // bounds := img.Bounds()
return bounds.Dx(), bounds.Dy(), nil // return bounds.Dx(), bounds.Dy(), nil
case "image/png": // case "image/png":
img, err := png.Decode(bytes.NewReader(cnt)) // img, err := png.Decode(bytes.NewReader(cnt))
if err != nil { // if err != nil {
return 0, 0, errors.Wrap(err, "decode png") // return 0, 0, errors.Wrap(err, "decode png")
} // }
bounds := img.Bounds() // bounds := img.Bounds()
return bounds.Dx(), bounds.Dy(), nil // return bounds.Dx(), bounds.Dy(), nil
default: // default:
return 0, 0, errors.Errorf("unsupport image content type %q", contentType) // return 0, 0, errors.Errorf("unsupport image content type %q", contentType)
} // }
} // }
func countTokenMessages(messages []Message, model string) int { func countTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
@ -178,52 +174,52 @@ func countTokenMessages(messages []Message, model string) int {
return tokenNum return tokenNum
} }
func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) { // func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
tokenEncoder := getTokenEncoder(model) // tokenEncoder := getTokenEncoder(model)
// Reference: // // Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb // // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6 // // https://github.com/pkoukk/tiktoken-go/issues/6
// // //
// Every message follows <|start|>{role/name}\n{content}<|end|>\n // // Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int // var tokensPerMessage int
var tokensPerName int // var tokensPerName int
if model == "gpt-3.5-turbo-0301" { // if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4 // tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted // tokensPerName = -1 // If there's a name, the role is omitted
} else { // } else {
tokensPerMessage = 3 // tokensPerMessage = 3
tokensPerName = 1 // tokensPerName = 1
} // }
tokenNum := 0 // tokenNum := 0
for _, message := range messages { // for _, message := range messages {
tokenNum += tokensPerMessage // tokenNum += tokensPerMessage
for _, cnt := range message.Content { // for _, cnt := range message.Content {
switch cnt.Type { // switch cnt.Type {
case OpenaiVisionMessageContentTypeText: // case OpenaiVisionMessageContentTypeText:
tokenNum += getTokenNum(tokenEncoder, cnt.Text) // tokenNum += getTokenNum(tokenEncoder, cnt.Text)
case OpenaiVisionMessageContentTypeImageUrl: // case OpenaiVisionMessageContentTypeImageUrl:
imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,")) // imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,"))
if err != nil { // if err != nil {
return 0, errors.Wrap(err, "failed to decode base64 image") // return 0, errors.Wrap(err, "failed to decode base64 image")
} // }
if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil { // if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil {
return 0, errors.Wrap(err, "failed to count vision image token") // return 0, errors.Wrap(err, "failed to count vision image token")
} else { // } else {
tokenNum += imgtoken // tokenNum += imgtoken
} // }
} // }
} // }
tokenNum += getTokenNum(tokenEncoder, message.Role) // tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil { // if message.Name != nil {
tokenNum += tokensPerName // tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name) // tokenNum += getTokenNum(tokenEncoder, *message.Name)
} // }
} // }
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> // tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum, nil // return tokenNum, nil
} // }
const ( const (
lowDetailCost = 85 lowDetailCost = 85