Compare commits

...

3 Commits

Author SHA1 Message Date
Laisky.Cai
16ddc0af83 fix: enhance image URL validation and error handling 2025-04-01 11:28:21 +00:00
Laisky.Cai
747e84851e fix: improve error handling for image URL fetching and validation 2025-04-01 10:16:50 +00:00
CaiCandong
ecadb19791 fix 2025-04-01 10:16:35 +00:00
6 changed files with 148 additions and 19 deletions

View File

@ -3,7 +3,6 @@ package image
import (
"bytes"
"encoding/base64"
"github.com/songquanpeng/one-api/common/client"
"image"
_ "image/gif"
_ "image/jpeg"
@ -13,6 +12,8 @@ import (
"strings"
"sync"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/client"
_ "golang.org/x/image/webp"
)
@ -22,27 +23,54 @@ var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
func IsImageUrl(url string) (bool, error) {
resp, err := client.UserContentRequestHTTPClient.Head(url)
if err != nil {
return false, err
return false, errors.Wrapf(err, "failed to fetch image URL: %s", url)
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// this file may not support HEAD method
resp, err = client.UserContentRequestHTTPClient.Get(url)
if err != nil {
return false, errors.Wrapf(err, "failed to fetch image URL: %s", url)
}
defer resp.Body.Close()
}
if resp.StatusCode != http.StatusOK {
return false, errors.Errorf("failed to fetch image URL: %s, status code: %d", url, resp.StatusCode)
}
if resp.ContentLength > 10*1024*1024 {
return false, errors.Errorf("image size should not exceed 10MB: %s, size: %d", url, resp.ContentLength)
}
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
if !strings.HasPrefix(contentType, "image/") &&
!strings.Contains(contentType, "application/octet-stream") {
return false,
errors.Errorf("invalid content type: %s, expected image type", contentType)
}
return true, nil
}
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if err != nil {
return 0, 0, errors.Wrap(err, "failed to fetch image URL")
}
if !isImage {
return
return 0, 0, errors.New("not an image URL")
}
resp, err := client.UserContentRequestHTTPClient.Get(url)
if err != nil {
return
return 0, 0, errors.Wrap(err, "failed to get image from URL")
}
defer resp.Body.Close()
img, _, err := image.DecodeConfig(resp.Body)
if err != nil {
return
return 0, 0, errors.Wrap(err, "failed to decode image")
}
return img.Width, img.Height, nil
}
@ -58,22 +86,35 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
}
isImage, err := IsImageUrl(url)
if !isImage {
return
if err != nil {
return mimeType, data, errors.Wrap(err, "failed to fetch image URL")
}
if !isImage {
return mimeType, data, errors.New("not an image URL")
}
resp, err := http.Get(url)
if err != nil {
return
return mimeType, data, errors.Wrap(err, "failed to get image from URL")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return mimeType, data, errors.Errorf("failed to fetch image URL: %s, status code: %d", url, resp.StatusCode)
}
if resp.ContentLength > 10*1024*1024 {
return mimeType, data, errors.Errorf("image size should not exceed 10MB: %s, size: %d", url, resp.ContentLength)
}
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
return mimeType, data, errors.Wrap(err, "failed to read image data from response")
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
return mimeType, data, nil
}
var (

View File

@ -1,8 +1,8 @@
package image_test
import (
"bytes"
"encoding/base64"
"github.com/songquanpeng/one-api/common/client"
"image"
_ "image/gif"
_ "image/jpeg"
@ -13,8 +13,8 @@ import (
"strings"
"testing"
"github.com/songquanpeng/one-api/common/client"
img "github.com/songquanpeng/one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"
)
@ -51,6 +51,8 @@ func TestMain(m *testing.M) {
}
func TestDecode(t *testing.T) {
t.Parallel()
// Bytes read: varies sometimes
// jpeg: 1063892
// png: 294462
@ -96,6 +98,8 @@ func TestDecode(t *testing.T) {
}
func TestBase64(t *testing.T) {
t.Parallel()
// Bytes read:
// jpeg: 1063892
// png: 294462
@ -149,6 +153,8 @@ func TestBase64(t *testing.T) {
}
func TestGetImageSize(t *testing.T) {
t.Parallel()
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
width, height, err := img.GetImageSize(c.url)
@ -160,6 +166,8 @@ func TestGetImageSize(t *testing.T) {
}
func TestGetImageSizeFromBase64(t *testing.T) {
t.Parallel()
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
resp, err := http.Get(c.url)
@ -175,3 +183,83 @@ func TestGetImageSizeFromBase64(t *testing.T) {
})
}
}
func TestGetImageFromUrl(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantMime string
wantErr bool
errMessage string
}{
{
name: "Valid JPEG URL",
input: cases[0].url, // Using the existing JPEG test case
wantMime: "image/jpeg",
wantErr: false,
},
{
name: "Valid PNG URL",
input: cases[1].url, // Using the existing PNG test case
wantMime: "image/png",
wantErr: false,
},
{
name: "Valid Data URL",
input: "",
wantMime: "image/png",
wantErr: false,
},
{
name: "Invalid URL",
input: "https://invalid.example.com/nonexistent.jpg",
wantErr: true,
errMessage: "failed to fetch image URL",
},
{
name: "Non-image URL",
input: "https://ario.laisky.com/alias/doc",
wantErr: true,
errMessage: "invalid content type",
},
}
for _, tt := range tests {
tt := tt // capture range variable
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mimeType, data, err := img.GetImageFromUrl(tt.input)
if tt.wantErr {
assert.Error(t, err)
if tt.errMessage != "" {
assert.Contains(t, err.Error(), tt.errMessage)
}
return
}
assert.NoError(t, err)
assert.NotEmpty(t, data)
// For data URLs, we should verify the mime type matches the input
if strings.HasPrefix(tt.input, "data:image/") {
assert.Equal(t, tt.wantMime, mimeType)
return
}
// For regular URLs, verify the base64 data is valid and can be decoded
decoded, err := base64.StdEncoding.DecodeString(data)
assert.NoError(t, err)
assert.NotEmpty(t, decoded)
// Verify the decoded data is a valid image
reader := bytes.NewReader(decoded)
_, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, strings.TrimPrefix(tt.wantMime, "image/"), format)
})
}
}

View File

@ -327,7 +327,7 @@ const ChannelsTable = () => {
let res;
switch (action) {
case 'delete':
res = await API.delete(`/api/channel/${id}/`);
res = await API.delete(`/api/channel/${id}`);
break;
case 'enable':
data.status = 1;

View File

@ -250,7 +250,7 @@ const RedemptionsTable = () => {
let res;
switch (action) {
case 'delete':
res = await API.delete(`/api/redemption/${id}/`);
res = await API.delete(`/api/redemption/${id}`);
break;
case 'enable':
data.status = 1;

View File

@ -165,7 +165,7 @@ const ChannelsTable = () => {
let res;
switch (action) {
case 'delete':
res = await API.delete(`/api/channel/${id}/`);
res = await API.delete(`/api/channel/${id}`);
break;
case 'enable':
data.status = 1;
@ -360,7 +360,7 @@ const ChannelsTable = () => {
};
const updateChannelBalance = async (id, name, idx) => {
const res = await API.get(`/api/channel/update_balance/${id}/`);
const res = await API.get(`/api/channel/update_balance/${id}`);
const { success, message, balance } = res.data;
if (success) {
let newChannels = [...channels];

View File

@ -103,7 +103,7 @@ const RedemptionsTable = () => {
let res;
switch (action) {
case 'delete':
res = await API.delete(`/api/redemption/${id}/`);
res = await API.delete(`/api/redemption/${id}`);
break;
case 'enable':
data.status = 1;