From 16ddc0af83a3fe4f06f88a663c14fb98eb87f042 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Tue, 1 Apr 2025 11:27:32 +0000 Subject: [PATCH] fix: enhance image URL validation and error handling --- common/image/image.go | 28 +++++++++++- common/image/image_test.go | 92 +++++++++++++++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 3 deletions(-) diff --git a/common/image/image.go b/common/image/image.go index 6a10cb19..14488627 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -14,7 +14,6 @@ import ( "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/client" - _ "golang.org/x/image/webp" ) @@ -26,6 +25,16 @@ func IsImageUrl(url string) (bool, error) { if err != nil { return false, errors.Wrapf(err, "failed to fetch image URL: %s", url) } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + // this file may not support HEAD method + resp, err = client.UserContentRequestHTTPClient.Get(url) + if err != nil { + return false, errors.Wrapf(err, "failed to fetch image URL: %s", url) + } + defer resp.Body.Close() + } if resp.StatusCode != http.StatusOK { return false, errors.Errorf("failed to fetch image URL: %s, status code: %d", url, resp.StatusCode) @@ -35,6 +44,13 @@ func IsImageUrl(url string) (bool, error) { return false, errors.Errorf("image size should not exceed 10MB: %s, size: %d", url, resp.ContentLength) } + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + if !strings.HasPrefix(contentType, "image/") && + !strings.Contains(contentType, "application/octet-stream") { + return false, + errors.Errorf("invalid content type: %s, expected image type", contentType) + } + return true, nil } @@ -51,6 +67,7 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) { return 0, 0, errors.Wrap(err, "failed to get image from URL") } defer resp.Body.Close() + img, _, err := image.DecodeConfig(resp.Body) if err != nil { return 0, 0, errors.Wrap(err, "failed to decode image") @@ -81,11 +98,20 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) { return mimeType, data, errors.Wrap(err, "failed to get image from URL") } defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return mimeType, data, errors.Errorf("failed to fetch image URL: %s, status code: %d", url, resp.StatusCode) + } + if resp.ContentLength > 10*1024*1024 { + return mimeType, data, errors.Errorf("image size should not exceed 10MB: %s, size: %d", url, resp.ContentLength) + } + buffer := bytes.NewBuffer(nil) _, err = buffer.ReadFrom(resp.Body) if err != nil { return mimeType, data, errors.Wrap(err, "failed to read image data from response") } + mimeType = resp.Header.Get("Content-Type") data = base64.StdEncoding.EncodeToString(buffer.Bytes()) return mimeType, data, nil diff --git a/common/image/image_test.go b/common/image/image_test.go index 5b669b51..5774ef1d 100644 --- a/common/image/image_test.go +++ b/common/image/image_test.go @@ -1,8 +1,8 @@ package image_test import ( + "bytes" "encoding/base64" - "github.com/songquanpeng/one-api/common/client" "image" _ "image/gif" _ "image/jpeg" @@ -13,8 +13,8 @@ import ( "strings" "testing" + "github.com/songquanpeng/one-api/common/client" img "github.com/songquanpeng/one-api/common/image" - "github.com/stretchr/testify/assert" _ "golang.org/x/image/webp" ) @@ -51,6 +51,8 @@ func TestMain(m *testing.M) { } func TestDecode(t *testing.T) { + t.Parallel() + // Bytes read: varies sometimes // jpeg: 1063892 // png: 294462 @@ -96,6 +98,8 @@ func TestDecode(t *testing.T) { } func TestBase64(t *testing.T) { + t.Parallel() + // Bytes read: // jpeg: 1063892 // png: 294462 @@ -149,6 +153,8 @@ func TestBase64(t *testing.T) { } func TestGetImageSize(t *testing.T) { + t.Parallel() + for i, c := range cases { t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { width, height, err := img.GetImageSize(c.url) @@ -160,6 +166,8 @@ func TestGetImageSize(t *testing.T) { } func TestGetImageSizeFromBase64(t *testing.T) { + t.Parallel() + for i, c := range cases { t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { resp, err := http.Get(c.url) @@ -175,3 +183,83 @@ func TestGetImageSizeFromBase64(t *testing.T) { }) } } + +func TestGetImageFromUrl(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantMime string + wantErr bool + errMessage string + }{ + { + name: "Valid JPEG URL", + input: cases[0].url, // Using the existing JPEG test case + wantMime: "image/jpeg", + wantErr: false, + }, + { + name: "Valid PNG URL", + input: cases[1].url, // Using the existing PNG test case + wantMime: "image/png", + wantErr: false, + }, + { + name: "Valid Data URL", + input: "", + wantMime: "image/png", + wantErr: false, + }, + { + name: "Invalid URL", + input: "https://invalid.example.com/nonexistent.jpg", + wantErr: true, + errMessage: "failed to fetch image URL", + }, + { + name: "Non-image URL", + input: "https://ario.laisky.com/alias/doc", + wantErr: true, + errMessage: "invalid content type", + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mimeType, data, err := img.GetImageFromUrl(tt.input) + + if tt.wantErr { + assert.Error(t, err) + if tt.errMessage != "" { + assert.Contains(t, err.Error(), tt.errMessage) + } + return + } + + assert.NoError(t, err) + assert.NotEmpty(t, data) + + // For data URLs, we should verify the mime type matches the input + if strings.HasPrefix(tt.input, "data:image/") { + assert.Equal(t, tt.wantMime, mimeType) + return + } + + // For regular URLs, verify the base64 data is valid and can be decoded + decoded, err := base64.StdEncoding.DecodeString(data) + assert.NoError(t, err) + assert.NotEmpty(t, decoded) + + // Verify the decoded data is a valid image + reader := bytes.NewReader(decoded) + _, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, strings.TrimPrefix(tt.wantMime, "image/"), format) + }) + } +}