fix: Add image support to Gemini relay

- Add support for getting base64-encoded images via openAI's image_url.
- Add `context` as a parameter for the function `LogError`.
- Handle the error from `image.GetImageFromUrl` by logging it.
- Convert the role to `user` if it is `system` and add a dummy model message to make Gemini happy.
This commit is contained in:
Laisky.Cai 2023-12-27 03:13:11 +00:00
parent 75cbfd7bb6
commit 671fe78e44
3 changed files with 17 additions and 3 deletions

View File

@ -44,6 +44,11 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
} }
func GetImageFromUrl(url string) (mimeType string, data string, err error) { func GetImageFromUrl(url string) (mimeType string, data string, err error) {
// openai's image_url support base64 encoded image
if strings.HasPrefix(url, "data:image/jpeg;base64,") {
return "image/jpeg", strings.TrimPrefix(url, "data:image/jpeg;base64,"), nil
}
isImage, err := IsImageUrl(url) isImage, err := IsImageUrl(url)
if !isImage { if !isImage {
return return

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"bufio" "bufio"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -117,7 +118,12 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
if imageNum > GeminiVisionMaxImageNum { if imageNum > GeminiVisionMaxImageNum {
continue continue
} }
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.Url)
if err != nil {
common.LogError(context.TODO(),
fmt.Sprintf("get image from url %s got %+v", part.ImageURL.Url, err))
}
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &GeminiInlineData{
MimeType: mimeType, MimeType: mimeType,

View File

@ -262,7 +262,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
requestBody = c.Request.Body requestBody = c.Request.Body
} }
common.LogInfo(c.Request.Context(), fmt.Sprintf("convert to apitype %d", apiType)) common.LogInfo(c.Request.Context(), fmt.Sprintf(
"convert to apitype %d, channel_type %d, channel_id %d",
apiType, channelType, channelId))
switch apiType { switch apiType {
case APITypeClaude: case APITypeClaude:
claudeRequest := requestOpenAI2Claude(textRequest) claudeRequest := requestOpenAI2Claude(textRequest)
@ -300,6 +302,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
} }
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
fmt.Println(">> convert request body to gemini: " + string(jsonStr)) // FIXME
// case APITypeZhipu: // case APITypeZhipu:
// zhipuRequest := requestOpenAI2Zhipu(textRequest) // zhipuRequest := requestOpenAI2Zhipu(textRequest)
// jsonStr, err := json.Marshal(zhipuRequest) // jsonStr, err := json.Marshal(zhipuRequest)
@ -431,7 +434,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
{ // more error info { // more error info
if reqdata, err := json.Marshal(textRequest); err != nil { if reqdata, err := json.Marshal(req.Body); err != nil {
fmt.Printf("[ERROR] marshal relay text error: %s\n", err.Error()) fmt.Printf("[ERROR] marshal relay text error: %s\n", err.Error())
} else { } else {
if respdata, err := io.ReadAll(resp.Body); err != nil { if respdata, err := io.ReadAll(resp.Body); err != nil {