diff --git a/internal/sub/external_subscription.go b/internal/sub/external_subscription.go index 4419abc75..71a3fabf9 100644 --- a/internal/sub/external_subscription.go +++ b/internal/sub/external_subscription.go @@ -78,14 +78,20 @@ func doFetchSubscriptionLinks(rawURL string) ([]string, error) { if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, errBadStatus } - body, err := io.ReadAll(io.LimitReader(resp.Body, subscriptionMaxBytes)) + body, err := io.ReadAll(io.LimitReader(resp.Body, subscriptionMaxBytes+1)) if err != nil { return nil, err } + if len(body) > subscriptionMaxBytes { + return nil, errSubscriptionBodyTooLarge + } return decodeSubscriptionBody(body), nil } -var errBadStatus = &subError{"non-2xx subscription response"} +var ( + errBadStatus = &subError{"non-2xx subscription response"} + errSubscriptionBodyTooLarge = &subError{"subscription response body exceeds size limit"} +) type subError struct{ msg string } diff --git a/internal/sub/external_subscription_test.go b/internal/sub/external_subscription_test.go new file mode 100644 index 000000000..60af51987 --- /dev/null +++ b/internal/sub/external_subscription_test.go @@ -0,0 +1,43 @@ +package sub + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestDoFetchSubscriptionLinks_RejectsOversizedBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(strings.Repeat("a", subscriptionMaxBytes+1))) + })) + defer srv.Close() + + links, err := doFetchSubscriptionLinks(srv.URL) + if err != errSubscriptionBodyTooLarge { + t.Fatalf("err = %v, want errSubscriptionBodyTooLarge", err) + } + if links != nil { + t.Fatalf("links = %v, want nil", links) + } +} + +func TestDoFetchSubscriptionLinks_AcceptsBodyAtLimit(t *testing.T) { + link := "vless://example" + body := link + "\n" + strings.Repeat("#", subscriptionMaxBytes-len(link)-1) + if len(body) != subscriptionMaxBytes { + t.Fatalf("fixture size = %d, want %d", len(body), subscriptionMaxBytes) + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(body)) + })) + defer srv.Close() + + links, err := doFetchSubscriptionLinks(srv.URL) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if len(links) != 1 || links[0] != link { + t.Fatalf("links = %v, want [%q]", links, link) + } +}