diff --git a/common/image/image.go b/common/image/image.go index beebd0c6..6a10cb19 100644 --- a/common/image/image.go +++ b/common/image/image.go @@ -3,7 +3,6 @@ package image import ( "bytes" "encoding/base64" - "github.com/songquanpeng/one-api/common/client" "image" _ "image/gif" _ "image/jpeg" @@ -13,6 +12,9 @@ import ( "strings" "sync" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/client" + _ "golang.org/x/image/webp" ) @@ -22,27 +24,36 @@ var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) func IsImageUrl(url string) (bool, error) { resp, err := client.UserContentRequestHTTPClient.Head(url) if err != nil { - return false, err + return false, errors.Wrapf(err, "failed to fetch image URL: %s", url) } - if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { - return false, nil + + if resp.StatusCode != http.StatusOK { + return false, errors.Errorf("failed to fetch image URL: %s, status code: %d", url, resp.StatusCode) } + + if resp.ContentLength > 10*1024*1024 { + return false, errors.Errorf("image size should not exceed 10MB: %s, size: %d", url, resp.ContentLength) + } + return true, nil } func GetImageSizeFromUrl(url string) (width int, height int, err error) { isImage, err := IsImageUrl(url) + if err != nil { + return 0, 0, errors.Wrap(err, "failed to fetch image URL") + } if !isImage { - return + return 0, 0, errors.New("not an image URL") } resp, err := client.UserContentRequestHTTPClient.Get(url) if err != nil { - return + return 0, 0, errors.Wrap(err, "failed to get image from URL") } defer resp.Body.Close() img, _, err := image.DecodeConfig(resp.Body) if err != nil { - return + return 0, 0, errors.Wrap(err, "failed to decode image") } return img.Width, img.Height, nil } @@ -58,22 +69,26 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) { } isImage, err := IsImageUrl(url) - if !isImage { - return + if err != nil { + return mimeType, data, errors.Wrap(err, "failed to fetch image URL") } + if !isImage { + return mimeType, data, errors.New("not an image URL") + } + resp, err := http.Get(url) if err != nil { - return + return mimeType, data, errors.Wrap(err, "failed to get image from URL") } defer resp.Body.Close() buffer := bytes.NewBuffer(nil) _, err = buffer.ReadFrom(resp.Body) if err != nil { - return + return mimeType, data, errors.Wrap(err, "failed to read image data from response") } mimeType = resp.Header.Get("Content-Type") data = base64.StdEncoding.EncodeToString(buffer.Bytes()) - return + return mimeType, data, nil } var (