support gemini-pro-vision

This commit is contained in:
CaIon 2023-12-27 16:32:54 +08:00
parent 4036355fae
commit 14592f9758
6 changed files with 141 additions and 15 deletions

View File

@ -12,7 +12,7 @@ import (
"strings" "strings"
) )
func DecodeBase64ImageData(base64String string) (image.Config, error) { func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
// 去除base64数据的URL前缀如果有 // 去除base64数据的URL前缀如果有
if idx := strings.Index(base64String, ","); idx != -1 { if idx := strings.Index(base64String, ","); idx != -1 {
base64String = base64String[idx+1:] base64String = base64String[idx+1:]
@ -22,20 +22,51 @@ func DecodeBase64ImageData(base64String string) (image.Config, error) {
decodedData, err := base64.StdEncoding.DecodeString(base64String) decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil { if err != nil {
fmt.Println("Error: Failed to decode base64 string") fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, err return image.Config{}, "", err
} }
// 创建一个bytes.Buffer用于存储解码后的数据 // 创建一个bytes.Buffer用于存储解码后的数据
reader := bytes.NewReader(decodedData) reader := bytes.NewReader(decodedData)
config, err := getImageConfig(reader) config, format, err := getImageConfig(reader)
return config, err return config, format, err
} }
func DecodeUrlImageData(imageUrl string) (image.Config, error) { func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := http.Get(imageUrl) response, err := http.Get(imageUrl)
if err != nil { if err != nil {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, err return image.Config{}, "", err
} }
// 限制读取的字节数,防止下载整个图片 // 限制读取的字节数,防止下载整个图片
@ -45,14 +76,14 @@ func DecodeUrlImageData(imageUrl string) (image.Config, error) {
// log.Fatal(err) // log.Fatal(err)
//} //}
//log.Printf("%x", data) //log.Printf("%x", data)
config, err := getImageConfig(limitReader) config, format, err := getImageConfig(limitReader)
response.Body.Close() response.Body.Close()
return config, err return config, format, err
} }
func getImageConfig(reader io.Reader) (image.Config, error) { func getImageConfig(reader io.Reader) (image.Config, string, error) {
// 读取图片的头部信息来获取图片尺寸 // 读取图片的头部信息来获取图片尺寸
config, _, err := image.DecodeConfig(reader) config, format, err := image.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
SysLog(err.Error()) SysLog(err.Error())
@ -61,9 +92,10 @@ func getImageConfig(reader io.Reader) (image.Config, error) {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
SysLog(err.Error()) SysLog(err.Error())
} }
format = "webp"
} }
if err != nil { if err != nil {
return image.Config{}, err return image.Config{}, "", err
} }
return config, nil return config, format, nil
} }

View File

@ -62,6 +62,7 @@ var ModelRatio = map[string]float64{
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1, "PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens

View File

@ -432,6 +432,15 @@ func init() {
Root: "gemini-pro", Root: "gemini-pro",
Parent: nil, Parent: nil,
}, },
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{ {
Id: "chatglm_turbo", Id: "chatglm_turbo",
Object: "model", Object: "model",

View File

@ -12,6 +12,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
const (
GeminiVisionMaxImageNum = 16
)
type GeminiChatRequest struct { type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"` Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
@ -97,6 +101,31 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
}, },
}, },
} }
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model // there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" { if content.Role == "assistant" {
content.Role = "model" content.Role = "model"

View File

@ -76,12 +76,13 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
} }
var config image.Config var config image.Config
var err error var err error
var format string
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url)) common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
config, err = common.DecodeUrlImageData(imageUrl.Url) config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else { } else {
common.SysLog(fmt.Sprintf("decoding image")) common.SysLog(fmt.Sprintf("decoding image"))
config, err = common.DecodeBase64ImageData(imageUrl.Url) config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
} }
if err != nil { if err != nil {
return 0, err return 0, err
@ -101,7 +102,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
shortSide := config.Width shortSide := config.Width
otherSide := config.Height otherSide := config.Height
log.Printf("width: %d, height: %d", config.Width, config.Height) log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
// 缩放倍数 // 缩放倍数
scale := 1.0 scale := 1.0
if config.Height < shortSide { if config.Height < shortSide {

View File

@ -29,6 +29,60 @@ type MessageImageUrl struct {
Detail string `json:"detail"` Detail string `json:"detail"`
} }
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
}
}
}
return contentList
}
return nil
}
const ( const (
RelayModeUnknown = iota RelayModeUnknown = iota
RelayModeChatCompletions RelayModeChatCompletions