diff --git a/common/image.go b/common/image.go
index 9c1da5a..2d93719 100644
--- a/common/image.go
+++ b/common/image.go
@@ -12,7 +12,9 @@ import (
 	"strings"
 )
 
-func DecodeBase64ImageData(base64String string) (image.Config, error) {
+// DecodeBase64ImageData drops everything up to the first comma (if any),
+// base64-decodes the rest and returns the image.Config and format name.
+func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
 	// 去除base64数据的URL前缀(如果有)
 	if idx := strings.Index(base64String, ","); idx != -1 {
 		base64String = base64String[idx+1:]
@@ -22,20 +24,61 @@ func DecodeBase64ImageData(base64String string) (image.Config, error) {
 	decodedData, err := base64.StdEncoding.DecodeString(base64String)
 	if err != nil {
 		fmt.Println("Error: Failed to decode base64 string")
-		return image.Config{}, err
+		return image.Config{}, "", err
 	}
 
 	// 创建一个bytes.Buffer用于存储解码后的数据
 	reader := bytes.NewReader(decodedData)
-	config, err := getImageConfig(reader)
-	return config, err
+	config, format, err := getImageConfig(reader)
+	return config, format, err
 }
 
-func DecodeUrlImageData(imageUrl string) (image.Config, error) {
+// IsImageUrl sends a HEAD request to url and reports whether the response's
+// Content-Type header begins with "image/".
+func IsImageUrl(url string) (bool, error) {
+	resp, err := http.Head(url)
+	if err != nil {
+		return false, err
+	}
+	// A HEAD response has no payload, but the body must still be closed so the
+	// underlying connection can be reused.
+	defer resp.Body.Close()
+	return strings.HasPrefix(resp.Header.Get("Content-Type"), "image/"), nil
+}
+
+// GetImageFromUrl downloads url and returns the response's Content-Type header
+// as mimeType and the body, encoded with standard base64, as data. If the URL
+// does not report an image Content-Type, both results are empty and err is nil.
+func GetImageFromUrl(url string) (mimeType string, data string, err error) {
+	isImage, err := IsImageUrl(url)
+	if err != nil || !isImage {
+		return "", "", err
+	}
+	resp, err := http.Get(url)
+	if err != nil {
+		return "", "", err
+	}
+	defer resp.Body.Close()
+	// Do not pass an error page on to the caller as if it were image data.
+	if resp.StatusCode != http.StatusOK {
+		return "", "", fmt.Errorf("get image from url: unexpected status %s", resp.Status)
+	}
+	var buffer bytes.Buffer
+	if _, err = buffer.ReadFrom(resp.Body); err != nil {
+		return "", "", err
+	}
+	mimeType = resp.Header.Get("Content-Type")
+	data = base64.StdEncoding.EncodeToString(buffer.Bytes())
+	return mimeType, data, nil
+}
+
+// DecodeUrlImageData fetches imageUrl and returns the image.Config and format
+// name decoded from a size-limited read of the response body.
+func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
 	response, err := http.Get(imageUrl)
 	if err != nil {
 		SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
-		return image.Config{}, err
+		return image.Config{}, "", err
 	}
 
 	// 限制读取的字节数,防止下载整个图片
@@ -45,14 +88,16 @@ func DecodeUrlImageData(imageUrl string) (image.Config, error) {
 	//	log.Fatal(err)
 	//}
 	//log.Printf("%x", data)
-	config, err := getImageConfig(limitReader)
+	config, format, err := getImageConfig(limitReader)
 	response.Body.Close()
-	return config, err
+	return config, format, err
 }
 
-func getImageConfig(reader io.Reader) (image.Config, error) {
+// getImageConfig decodes the image header from reader and returns its
+// image.Config and format name; the fallback decode path reports "webp".
+func getImageConfig(reader io.Reader) (image.Config, string, error) {
 	// 读取图片的头部信息来获取图片尺寸
-	config, _, err := image.DecodeConfig(reader)
+	config, format, err := image.DecodeConfig(reader)
 	if err != nil {
 		err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
 		SysLog(err.Error())
@@ -61,9 +106,10 @@ func getImageConfig(reader io.Reader) (image.Config, error) {
 			err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
 			SysLog(err.Error())
 		}
+		format = "webp"
 	}
 	if err != nil {
-		return image.Config{}, err
+		return image.Config{}, "", err
 	}
-	return config, nil
+	return config, format, nil
 }
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 3016f06..b18ba0d 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -62,6 +62,7 @@ var ModelRatio = map[string]float64{
 	"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2": 1,
 	"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
diff --git a/controller/model.go b/controller/model.go
index 9fa2132..84e7c4a 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -432,6 +432,15 @@ func init() {
 			Root: "gemini-pro",
 			Parent: nil,
 		},
+		{
+			Id: "gemini-pro-vision",
+			Object: "model",
+			Created: 1677649963,
+			OwnedBy: "google",
+			Permission: permission,
+			Root: "gemini-pro-vision",
+			Parent: nil,
+		},
 		{
 			Id: "chatglm_turbo",
 			Object: "model",
diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go
index d4ce18c..4454ff9 100644
--- a/controller/relay-gemini.go
+++ b/controller/relay-gemini.go
@@ -12,6 +12,12 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+const (
+	// GeminiVisionMaxImageNum is the maximum number of images per message that
+	// are converted into Gemini inline data; any further images are skipped.
+	GeminiVisionMaxImageNum = 16
+)
+
 type GeminiChatRequest struct {
 	Contents []GeminiChatContent `json:"contents"`
 	SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
@@ -97,6 +103,44 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 				},
 			},
 		}
+		openaiContent := message.ParseContent()
+		var parts []GeminiPart
+		imageNum := 0
+		for _, part := range openaiContent {
+			switch part.Type {
+			case ContentTypeText:
+				parts = append(parts, GeminiPart{
+					Text: part.Text,
+				})
+			case ContentTypeImageURL:
+				imageNum++
+				// Images beyond the limit are dropped rather than rejected.
+				if imageNum > GeminiVisionMaxImageNum {
+					continue
+				}
+				imageUrl, ok := part.ImageUrl.(MessageImageUrl)
+				if !ok {
+					continue
+				}
+				mimeType, data, err := common.GetImageFromUrl(imageUrl.Url)
+				if err != nil {
+					common.SysLog("fail to get image from url: " + err.Error())
+					continue
+				}
+				// An empty payload means the URL did not serve an image.
+				if data == "" {
+					continue
+				}
+				parts = append(parts, GeminiPart{
+					InlineData: &GeminiInlineData{
+						MimeType: mimeType,
+						Data: data,
+					},
+				})
+			}
+		}
+		content.Parts = parts
+
 		// there's no assistant role in gemini and API shall vomit if Role is not user or model
 		if content.Role == "assistant" {
 			content.Role = "model"
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index 3a556e7..7d41a0c 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -76,12 +76,13 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
 	}
 	var config image.Config
 	var err error
+	var format string
 	if strings.HasPrefix(imageUrl.Url, "http") {
 		common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
-		config, err = common.DecodeUrlImageData(imageUrl.Url)
+		config, format, err = common.DecodeUrlImageData(imageUrl.Url)
 	} else {
 		common.SysLog(fmt.Sprintf("decoding image"))
-		config, err = common.DecodeBase64ImageData(imageUrl.Url)
+		config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
 	}
 	if err != nil {
 		return 0, err
@@ -101,7 +102,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
 
 	shortSide := config.Width
 	otherSide := config.Height
-	log.Printf("width: %d, height: %d", config.Width, config.Height)
+	log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
 	// 缩放倍数
 	scale := 1.0
 	if config.Height < shortSide {
diff --git a/controller/relay.go b/controller/relay.go
index 9254fb2..4cce28d 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -29,6 +29,70 @@ type MessageImageUrl struct {
 	Detail string `json:"detail"`
 }
 
+// Content part types recognized in the "type" field of a message content item.
+const (
+	ContentTypeText = "text"
+	ContentTypeImageURL = "image_url"
+)
+
+// ParseContent normalizes the raw message content into a list of parts.
+// A JSON string becomes a single text part. A JSON array is scanned item by
+// item: "text" items with a string text and "image_url" items with a string
+// url are kept (detail defaults to "auto"); anything else is skipped.
+// It returns nil when the content is neither a string nor an array.
+func (m Message) ParseContent() []MediaMessage {
+	var stringContent string
+	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		return []MediaMessage{{
+			Type: ContentTypeText,
+			Text: stringContent,
+		}}
+	}
+	var arrayContent []json.RawMessage
+	if err := json.Unmarshal(m.Content, &arrayContent); err != nil {
+		return nil
+	}
+	var contentList []MediaMessage
+	for _, contentItem := range arrayContent {
+		var contentMap map[string]any
+		if err := json.Unmarshal(contentItem, &contentMap); err != nil {
+			continue
+		}
+		switch contentMap["type"] {
+		case ContentTypeText:
+			if subStr, ok := contentMap["text"].(string); ok {
+				contentList = append(contentList, MediaMessage{
+					Type: ContentTypeText,
+					Text: subStr,
+				})
+			}
+		case ContentTypeImageURL:
+			subObj, ok := contentMap["image_url"].(map[string]any)
+			if !ok {
+				continue
+			}
+			// The content comes straight from the client, so use checked
+			// assertions: a missing or non-string field must not panic.
+			rawUrl, ok := subObj["url"].(string)
+			if !ok {
+				continue
+			}
+			detail, ok := subObj["detail"].(string)
+			if !ok {
+				detail = "auto"
+			}
+			contentList = append(contentList, MediaMessage{
+				Type: ContentTypeImageURL,
+				ImageUrl: MessageImageUrl{
+					Url: rawUrl,
+					Detail: detail,
+				},
+			})
+		}
+	}
+	return contentList
+}
+
 const (
 	RelayModeUnknown = iota
 	RelayModeChatCompletions