refactor: Refactor code for improved efficiency and readability

- Refactored code in relay-utils.go
- Eliminated unused imports and redundant function
- Improved code readability with added comments
- Cleaned up by removing unnecessary commented-out code
This commit is contained in:
Laisky.Cai 2023-12-12 06:09:11 +00:00
parent 7239b3386a
commit c503a87c74

View File

@ -1,13 +1,9 @@
package controller
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"image/jpeg"
"image/png"
"io"
"math"
"net/http"
@ -78,49 +74,49 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
// CountVisionImageToken count vision image tokens
//
// https://openai.com/pricing
func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) {
width, height, err := imageSize(cnt)
if err != nil {
return 0, errors.Wrap(err, "get image size")
}
// func CountVisionImageToken(cnt []byte, resolution VisionImageResolution) (int, error) {
// width, height, err := imageSize(cnt)
// if err != nil {
// return 0, errors.Wrap(err, "get image size")
// }
switch resolution {
case VisionImageResolutionLow:
return 85, nil // fixed price
case VisionImageResolutionHigh:
h := math.Ceil(float64(height) / 512)
w := math.Ceil(float64(width) / 512)
n := w * h
total := 85 + 170*n
return int(total), nil
default:
return 0, errors.Errorf("unsupport resolution %q", resolution)
}
}
// switch resolution {
// case VisionImageResolutionLow:
// return 85, nil // fixed price
// case VisionImageResolutionHigh:
// h := math.Ceil(float64(height) / 512)
// w := math.Ceil(float64(width) / 512)
// n := w * h
// total := 85 + 170*n
// return int(total), nil
// default:
// return 0, errors.Errorf("unsupport resolution %q", resolution)
// }
// }
func imageSize(cnt []byte) (width, height int, err error) {
contentType := http.DetectContentType(cnt)
switch contentType {
case "image/jpeg", "image/jpg":
img, err := jpeg.Decode(bytes.NewReader(cnt))
if err != nil {
return 0, 0, errors.Wrap(err, "decode jpeg")
}
// func imageSize(cnt []byte) (width, height int, err error) {
// contentType := http.DetectContentType(cnt)
// switch contentType {
// case "image/jpeg", "image/jpg":
// img, err := jpeg.Decode(bytes.NewReader(cnt))
// if err != nil {
// return 0, 0, errors.Wrap(err, "decode jpeg")
// }
bounds := img.Bounds()
return bounds.Dx(), bounds.Dy(), nil
case "image/png":
img, err := png.Decode(bytes.NewReader(cnt))
if err != nil {
return 0, 0, errors.Wrap(err, "decode png")
}
// bounds := img.Bounds()
// return bounds.Dx(), bounds.Dy(), nil
// case "image/png":
// img, err := png.Decode(bytes.NewReader(cnt))
// if err != nil {
// return 0, 0, errors.Wrap(err, "decode png")
// }
bounds := img.Bounds()
return bounds.Dx(), bounds.Dy(), nil
default:
return 0, 0, errors.Errorf("unsupport image content type %q", contentType)
}
}
// bounds := img.Bounds()
// return bounds.Dx(), bounds.Dy(), nil
// default:
// return 0, 0, errors.Errorf("unsupport image content type %q", contentType)
// }
// }
func countTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
@ -178,52 +174,52 @@ func countTokenMessages(messages []Message, model string) int {
return tokenNum
}
func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
for _, cnt := range message.Content {
switch cnt.Type {
case OpenaiVisionMessageContentTypeText:
tokenNum += getTokenNum(tokenEncoder, cnt.Text)
case OpenaiVisionMessageContentTypeImageUrl:
imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,"))
if err != nil {
return 0, errors.Wrap(err, "failed to decode base64 image")
}
// func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) {
// tokenEncoder := getTokenEncoder(model)
// // Reference:
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// // https://github.com/pkoukk/tiktoken-go/issues/6
// //
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
// var tokensPerMessage int
// var tokensPerName int
// if model == "gpt-3.5-turbo-0301" {
// tokensPerMessage = 4
// tokensPerName = -1 // If there's a name, the role is omitted
// } else {
// tokensPerMessage = 3
// tokensPerName = 1
// }
// tokenNum := 0
// for _, message := range messages {
// tokenNum += tokensPerMessage
// for _, cnt := range message.Content {
// switch cnt.Type {
// case OpenaiVisionMessageContentTypeText:
// tokenNum += getTokenNum(tokenEncoder, cnt.Text)
// case OpenaiVisionMessageContentTypeImageUrl:
// imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,"))
// if err != nil {
// return 0, errors.Wrap(err, "failed to decode base64 image")
// }
if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil {
return 0, errors.Wrap(err, "failed to count vision image token")
} else {
tokenNum += imgtoken
}
}
}
// if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil {
// return 0, errors.Wrap(err, "failed to count vision image token")
// } else {
// tokenNum += imgtoken
// }
// }
// }
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum, nil
}
// tokenNum += getTokenNum(tokenEncoder, message.Role)
// if message.Name != nil {
// tokenNum += tokensPerName
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
// }
// }
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
// return tokenNum, nil
// }
const (
lowDetailCost = 85