fix: enhance image URL validation and error handling

This commit is contained in:
Laisky.Cai 2025-04-01 11:27:32 +00:00
parent 747e84851e
commit 16ddc0af83
2 changed files with 117 additions and 3 deletions

View File

@ -14,7 +14,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/client"
_ "golang.org/x/image/webp" _ "golang.org/x/image/webp"
) )
@ -26,6 +25,16 @@ func IsImageUrl(url string) (bool, error) {
if err != nil { if err != nil {
return false, errors.Wrapf(err, "failed to fetch image URL: %s", url) return false, errors.Wrapf(err, "failed to fetch image URL: %s", url)
} }
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// this file may not support HEAD method
resp, err = client.UserContentRequestHTTPClient.Get(url)
if err != nil {
return false, errors.Wrapf(err, "failed to fetch image URL: %s", url)
}
defer resp.Body.Close()
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return false, errors.Errorf("failed to fetch image URL: %s, status code: %d", url, resp.StatusCode) return false, errors.Errorf("failed to fetch image URL: %s, status code: %d", url, resp.StatusCode)
@ -35,6 +44,13 @@ func IsImageUrl(url string) (bool, error) {
return false, errors.Errorf("image size should not exceed 10MB: %s, size: %d", url, resp.ContentLength) return false, errors.Errorf("image size should not exceed 10MB: %s, size: %d", url, resp.ContentLength)
} }
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
if !strings.HasPrefix(contentType, "image/") &&
!strings.Contains(contentType, "application/octet-stream") {
return false,
errors.Errorf("invalid content type: %s, expected image type", contentType)
}
return true, nil return true, nil
} }
@ -51,6 +67,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
return 0, 0, errors.Wrap(err, "failed to get image from URL") return 0, 0, errors.Wrap(err, "failed to get image from URL")
} }
defer resp.Body.Close() defer resp.Body.Close()
img, _, err := image.DecodeConfig(resp.Body) img, _, err := image.DecodeConfig(resp.Body)
if err != nil { if err != nil {
return 0, 0, errors.Wrap(err, "failed to decode image") return 0, 0, errors.Wrap(err, "failed to decode image")
@ -81,11 +98,20 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
return mimeType, data, errors.Wrap(err, "failed to get image from URL") return mimeType, data, errors.Wrap(err, "failed to get image from URL")
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return mimeType, data, errors.Errorf("failed to fetch image URL: %s, status code: %d", url, resp.StatusCode)
}
if resp.ContentLength > 10*1024*1024 {
return mimeType, data, errors.Errorf("image size should not exceed 10MB: %s, size: %d", url, resp.ContentLength)
}
buffer := bytes.NewBuffer(nil) buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body) _, err = buffer.ReadFrom(resp.Body)
if err != nil { if err != nil {
return mimeType, data, errors.Wrap(err, "failed to read image data from response") return mimeType, data, errors.Wrap(err, "failed to read image data from response")
} }
mimeType = resp.Header.Get("Content-Type") mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes()) data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return mimeType, data, nil return mimeType, data, nil

View File

@ -1,8 +1,8 @@
package image_test package image_test
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"github.com/songquanpeng/one-api/common/client"
"image" "image"
_ "image/gif" _ "image/gif"
_ "image/jpeg" _ "image/jpeg"
@ -13,8 +13,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/songquanpeng/one-api/common/client"
img "github.com/songquanpeng/one-api/common/image" img "github.com/songquanpeng/one-api/common/image"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp" _ "golang.org/x/image/webp"
) )
@ -51,6 +51,8 @@ func TestMain(m *testing.M) {
} }
func TestDecode(t *testing.T) { func TestDecode(t *testing.T) {
t.Parallel()
// Bytes read: varies sometimes // Bytes read: varies sometimes
// jpeg: 1063892 // jpeg: 1063892
// png: 294462 // png: 294462
@ -96,6 +98,8 @@ func TestDecode(t *testing.T) {
} }
func TestBase64(t *testing.T) { func TestBase64(t *testing.T) {
t.Parallel()
// Bytes read: // Bytes read:
// jpeg: 1063892 // jpeg: 1063892
// png: 294462 // png: 294462
@ -149,6 +153,8 @@ func TestBase64(t *testing.T) {
} }
func TestGetImageSize(t *testing.T) { func TestGetImageSize(t *testing.T) {
t.Parallel()
for i, c := range cases { for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
width, height, err := img.GetImageSize(c.url) width, height, err := img.GetImageSize(c.url)
@ -160,6 +166,8 @@ func TestGetImageSize(t *testing.T) {
} }
func TestGetImageSizeFromBase64(t *testing.T) { func TestGetImageSizeFromBase64(t *testing.T) {
t.Parallel()
for i, c := range cases { for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
resp, err := http.Get(c.url) resp, err := http.Get(c.url)
@ -175,3 +183,83 @@ func TestGetImageSizeFromBase64(t *testing.T) {
}) })
} }
} }
func TestGetImageFromUrl(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantMime string
wantErr bool
errMessage string
}{
{
name: "Valid JPEG URL",
input: cases[0].url, // Using the existing JPEG test case
wantMime: "image/jpeg",
wantErr: false,
},
{
name: "Valid PNG URL",
input: cases[1].url, // Using the existing PNG test case
wantMime: "image/png",
wantErr: false,
},
{
name: "Valid Data URL",
input: "",
wantMime: "image/png",
wantErr: false,
},
{
name: "Invalid URL",
input: "https://invalid.example.com/nonexistent.jpg",
wantErr: true,
errMessage: "failed to fetch image URL",
},
{
name: "Non-image URL",
input: "https://ario.laisky.com/alias/doc",
wantErr: true,
errMessage: "invalid content type",
},
}
for _, tt := range tests {
tt := tt // capture range variable
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mimeType, data, err := img.GetImageFromUrl(tt.input)
if tt.wantErr {
assert.Error(t, err)
if tt.errMessage != "" {
assert.Contains(t, err.Error(), tt.errMessage)
}
return
}
assert.NoError(t, err)
assert.NotEmpty(t, data)
// For data URLs, we should verify the mime type matches the input
if strings.HasPrefix(tt.input, "data:image/") {
assert.Equal(t, tt.wantMime, mimeType)
return
}
// For regular URLs, verify the base64 data is valid and can be decoded
decoded, err := base64.StdEncoding.DecodeString(data)
assert.NoError(t, err)
assert.NotEmpty(t, decoded)
// Verify the decoded data is a valid image
reader := bytes.NewReader(decoded)
_, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, strings.TrimPrefix(tt.wantMime, "image/"), format)
})
}
}