package image_test import ( "bytes" "encoding/base64" "image" _ "image/gif" _ "image/jpeg" _ "image/png" "io" "net/http" "strconv" "strings" "testing" "github.com/songquanpeng/one-api/common/client" img "github.com/songquanpeng/one-api/common/image" "github.com/stretchr/testify/assert" _ "golang.org/x/image/webp" ) type CountingReader struct { reader io.Reader BytesRead int } func (r *CountingReader) Read(p []byte) (n int, err error) { n, err = r.reader.Read(p) r.BytesRead += n return n, err } var ( cases = []struct { url string format string width int height int }{ {"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669}, {"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592}, {"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985}, {"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533}, {"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230}, } ) func TestMain(m *testing.M) { client.Init() m.Run() } func TestDecode(t *testing.T) { t.Parallel() // Bytes read: varies sometimes // jpeg: 1063892 // png: 294462 // webp: 99529 // gif: 956153 // jpeg#01: 32805 for _, c := range cases { t.Run("Decode:"+c.format, func(t *testing.T) { resp, err := http.Get(c.url) assert.NoError(t, err) defer resp.Body.Close() reader := &CountingReader{reader: resp.Body} img, format, err := image.Decode(reader) assert.NoError(t, err) size := img.Bounds().Size() assert.Equal(t, c.format, format) assert.Equal(t, c.width, size.X) assert.Equal(t, c.height, size.Y) t.Logf("Bytes read: %d", reader.BytesRead) }) } // Bytes read: // jpeg: 4096 // png: 4096 // webp: 4096 // gif: 4096 // jpeg#01: 4096 for _, c := range cases { t.Run("DecodeConfig:"+c.format, func(t *testing.T) { resp, err := http.Get(c.url) assert.NoError(t, err) defer resp.Body.Close() reader := &CountingReader{reader: resp.Body} config, format, err := image.DecodeConfig(reader) assert.NoError(t, err) assert.Equal(t, c.format, format) assert.Equal(t, c.width, config.Width) assert.Equal(t, c.height, config.Height) t.Logf("Bytes read: %d", reader.BytesRead) }) } } func TestBase64(t *testing.T) { t.Parallel() // Bytes read: // jpeg: 1063892 // png: 294462 // webp: 99072 // gif: 953856 // jpeg#01: 32805 for _, c := range cases { t.Run("Decode:"+c.format, func(t *testing.T) { resp, err := http.Get(c.url) assert.NoError(t, err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) assert.NoError(t, err) encoded := base64.StdEncoding.EncodeToString(data) body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) reader := &CountingReader{reader: body} img, format, err := image.Decode(reader) assert.NoError(t, err) size := img.Bounds().Size() assert.Equal(t, c.format, format) assert.Equal(t, c.width, size.X) assert.Equal(t, c.height, size.Y) t.Logf("Bytes read: %d", reader.BytesRead) }) } // Bytes read: // jpeg: 1536 // png: 768 // webp: 768 // gif: 1536 // jpeg#01: 3840 for _, c := range cases { t.Run("DecodeConfig:"+c.format, func(t *testing.T) { resp, err := http.Get(c.url) assert.NoError(t, err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) assert.NoError(t, err) encoded := base64.StdEncoding.EncodeToString(data) body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) reader := &CountingReader{reader: body} config, format, err := image.DecodeConfig(reader) assert.NoError(t, err) assert.Equal(t, c.format, format) assert.Equal(t, c.width, config.Width) assert.Equal(t, c.height, config.Height) t.Logf("Bytes read: %d", reader.BytesRead) }) } } func TestGetImageSize(t *testing.T) { t.Parallel() for i, c := range cases { t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { width, height, err := img.GetImageSize(c.url) assert.NoError(t, err) assert.Equal(t, c.width, width) assert.Equal(t, c.height, height) }) } } func TestGetImageSizeFromBase64(t *testing.T) { t.Parallel() for i, c := range cases { t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { resp, err := http.Get(c.url) assert.NoError(t, err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) assert.NoError(t, err) encoded := base64.StdEncoding.EncodeToString(data) width, height, err := img.GetImageSizeFromBase64(encoded) assert.NoError(t, err) assert.Equal(t, c.width, width) assert.Equal(t, c.height, height) }) } } func TestGetImageFromUrl(t *testing.T) { t.Parallel() tests := []struct { name string input string wantMime string wantErr bool errMessage string }{ { name: "Valid JPEG URL", input: cases[0].url, // Using the existing JPEG test case wantMime: "image/jpeg", wantErr: false, }, { name: "Valid PNG URL", input: cases[1].url, // Using the existing PNG test case wantMime: "image/png", wantErr: false, }, { name: "Valid Data URL", input: "", wantMime: "image/png", wantErr: false, }, { name: "Invalid URL", input: "https://invalid.example.com/nonexistent.jpg", wantErr: true, errMessage: "failed to fetch image URL", }, { name: "Non-image URL", input: "https://ario.laisky.com/alias/doc", wantErr: true, errMessage: "invalid content type", }, } for _, tt := range tests { tt := tt // capture range variable t.Run(tt.name, func(t *testing.T) { t.Parallel() mimeType, data, err := img.GetImageFromUrl(tt.input) if tt.wantErr { assert.Error(t, err) if tt.errMessage != "" { assert.Contains(t, err.Error(), tt.errMessage) } return } assert.NoError(t, err) assert.NotEmpty(t, data) // For data URLs, we should verify the mime type matches the input if strings.HasPrefix(tt.input, "data:image/") { assert.Equal(t, tt.wantMime, mimeType) return } // For regular URLs, verify the base64 data is valid and can be decoded decoded, err := base64.StdEncoding.DecodeString(data) assert.NoError(t, err) assert.NotEmpty(t, decoded) // Verify the decoded data is a valid image reader := bytes.NewReader(decoded) _, format, err := image.DecodeConfig(reader) assert.NoError(t, err) assert.Equal(t, strings.TrimPrefix(tt.wantMime, "image/"), format) }) } }