diff --git a/CHANGELOG.md b/CHANGELOG.md index 94625408..ba4eb149 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ ## Changes since v7.6.0 +- [#2539](https://github.com/oauth2-proxy/oauth2-proxy/pull/2539) pkg/http: Fix leaky test (@isodude) + # V7.6.0 ## Release Highlights diff --git a/pkg/http/http_suite_test.go b/pkg/http/http_suite_test.go index 79aa19a8..edd440c7 100644 --- a/pkg/http/http_suite_test.go +++ b/pkg/http/http_suite_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/pem" @@ -18,7 +19,7 @@ import ( var ipv4CertData, ipv6CertData []byte var ipv4CertDataSource, ipv4KeyDataSource options.SecretSource var ipv6CertDataSource, ipv6KeyDataSource options.SecretSource -var client *http.Client +var transport *http.Transport func TestHTTPSuite(t *testing.T) { logger.SetOutput(GinkgoWriter) @@ -28,6 +29,17 @@ func TestHTTPSuite(t *testing.T) { RunSpecs(t, "HTTP") } +func httpGet(ctx context.Context, url string) (*http.Response, error) { + c := &http.Client{ + Transport: transport.Clone(), + } + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + var _ = BeforeSuite(func() { By("Generating a ipv4 self-signed cert for TLS tests", func() { certBytes, keyBytes, err := util.GenerateCert("127.0.0.1") @@ -70,11 +82,7 @@ var _ = BeforeSuite(func() { certpool.AddCert(ipv4certificate) certpool.AddCert(ipv6certificate) - transport := http.DefaultTransport.(*http.Transport).Clone() + transport = http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig.RootCAs = certpool - - client = &http.Client{ - Transport: transport, - } }) }) diff --git a/pkg/http/server_test.go b/pkg/http/server_test.go index 4d4ab860..aae7458e 100644 --- a/pkg/http/server_test.go +++ b/pkg/http/server_test.go @@ -12,6 +12,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" + . "github.com/onsi/gomega/gleak" ) const hello = "Hello World!" @@ -559,6 +560,8 @@ var _ = Describe("Server", func() { AfterEach(func() { cancel() + Eventually(Goroutines).ShouldNot(HaveLeaked()) + }) Context("with an ipv4 http server", func() { @@ -584,7 +587,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(listenAddr) + resp, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -599,13 +602,13 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) return err }).Should(HaveOccurred()) }) @@ -638,7 +641,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -653,13 +656,13 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) return err }).Should(HaveOccurred()) }) @@ -670,7 +673,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -709,7 +712,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(listenAddr) + resp, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -724,7 +727,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -739,19 +742,19 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) - _, err = client.Get(secureListenAddr) + _, err = httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) return err }).Should(HaveOccurred()) Eventually(func() error { - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) return err }).Should(HaveOccurred()) }) @@ -781,7 +784,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(listenAddr) + resp, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -796,13 +799,13 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) return err }).Should(HaveOccurred()) }) @@ -836,7 +839,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -851,13 +854,13 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) return err }).Should(HaveOccurred()) }) @@ -868,7 +871,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -908,7 +911,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(listenAddr) + resp, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -923,7 +926,7 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - resp, err := client.Get(secureListenAddr) + resp, err := httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(http.StatusOK)) @@ -938,19 +941,19 @@ var _ = Describe("Server", func() { Expect(srv.Start(ctx)).To(Succeed()) }() - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) Expect(err).ToNot(HaveOccurred()) - _, err = client.Get(secureListenAddr) + _, err = httpGet(ctx, secureListenAddr) Expect(err).ToNot(HaveOccurred()) cancel() Eventually(func() error { - _, err := client.Get(listenAddr) + _, err := httpGet(ctx, listenAddr) return err }).Should(HaveOccurred()) Eventually(func() error { - _, err := client.Get(secureListenAddr) + _, err := httpGet(ctx, secureListenAddr) return err }).Should(HaveOccurred()) })