From f3e255304342d7f08c7a5f34d3b739ce11b20deb Mon Sep 17 00:00:00 2001 From: Michael Katzenellenbogen Date: Tue, 19 Aug 2025 18:00:48 -0400 Subject: [PATCH] refactor how certificates are generated, so that a new one can be created within tests --- pkg/http/http_suite_test.go | 78 ++++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 27 deletions(-) diff --git a/pkg/http/http_suite_test.go b/pkg/http/http_suite_test.go index 19d4d3ff..43d0bc48 100644 --- a/pkg/http/http_suite_test.go +++ b/pkg/http/http_suite_test.go @@ -16,6 +16,7 @@ import ( . "github.com/onsi/gomega" ) +var ipv4Addr, ipv6Addr = "127.0.0.1", "::1" var ipv4CertData, ipv6CertData []byte var ipv4CertDataSource, ipv4KeyDataSource options.SecretSource var ipv6CertDataSource, ipv6KeyDataSource options.SecretSource @@ -40,49 +41,72 @@ func httpGet(ctx context.Context, url string) (*http.Response, error) { return c.Do(req) } +func generateCert(ipaddr string) (certData, certOutBytes, keyOutBytes []byte, err error) { + certBytes, keyBytes, err := util.GenerateCert(ipaddr) + if err != nil { + return + } + certData = certBytes + + certOut := new(bytes.Buffer) + if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}); err != nil { + return + } + certOutBytes = certOut.Bytes() + + keyOut := new(bytes.Buffer) + if err = pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes}); err != nil { + return + } + keyOutBytes = keyOut.Bytes() + + return +} + +func generateX509Cert(certSource, keySource options.SecretSource) (*x509.Certificate, error) { + cert, err := tls.X509KeyPair(certSource.Value, keySource.Value) + if err != nil { + return nil, err + } + + certificate, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, err + } + + return certificate, nil +} + +func addCertToTransportRootCAs(transport *http.Transport, cert ...*x509.Certificate) { + transport.TLSClientConfig.RootCAs = x509.NewCertPool() + for _, c := range cert { + transport.TLSClientConfig.RootCAs.AddCert(c) + } +} + var _ = BeforeSuite(func() { By("Generating a ipv4 self-signed cert for TLS tests", func() { - certBytes, keyBytes, err := util.GenerateCert("127.0.0.1") + ipv4Cert, ipv4CertBytes, ipv4KeyBytes, err := generateCert(ipv4Addr) Expect(err).ToNot(HaveOccurred()) - ipv4CertData = certBytes - certOut := new(bytes.Buffer) - Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed()) - ipv4CertDataSource.Value = certOut.Bytes() - keyOut := new(bytes.Buffer) - Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes})).To(Succeed()) - ipv4KeyDataSource.Value = keyOut.Bytes() + ipv4CertData, ipv4CertDataSource.Value, ipv4KeyDataSource.Value = ipv4Cert, ipv4CertBytes, ipv4KeyBytes }) By("Generating a ipv6 self-signed cert for TLS tests", func() { - certBytes, keyBytes, err := util.GenerateCert("::1") + ipv6Cert, ipv6CertBytes, ipv6KeyBytes, err := generateCert(ipv6Addr) Expect(err).ToNot(HaveOccurred()) - ipv6CertData = certBytes - certOut := new(bytes.Buffer) - Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed()) - ipv6CertDataSource.Value = certOut.Bytes() - keyOut := new(bytes.Buffer) - Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes})).To(Succeed()) - ipv6KeyDataSource.Value = keyOut.Bytes() + ipv6CertData, ipv6CertDataSource.Value, ipv6KeyDataSource.Value = ipv6Cert, ipv6CertBytes, ipv6KeyBytes }) By("Setting up a http client", func() { - ipv4cert, err := tls.X509KeyPair(ipv4CertDataSource.Value, ipv4KeyDataSource.Value) - Expect(err).ToNot(HaveOccurred()) - ipv6cert, err := tls.X509KeyPair(ipv6CertDataSource.Value, ipv6KeyDataSource.Value) + ipv4certificate, err := generateX509Cert(ipv4CertDataSource, ipv4KeyDataSource) Expect(err).ToNot(HaveOccurred()) - ipv4certificate, err := x509.ParseCertificate(ipv4cert.Certificate[0]) + ipv6certificate, err := generateX509Cert(ipv6CertDataSource, ipv6KeyDataSource) Expect(err).ToNot(HaveOccurred()) - ipv6certificate, err := x509.ParseCertificate(ipv6cert.Certificate[0]) - Expect(err).ToNot(HaveOccurred()) - - certpool := x509.NewCertPool() - certpool.AddCert(ipv4certificate) - certpool.AddCert(ipv6certificate) transport = http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig.RootCAs = certpool + addCertToTransportRootCAs(transport, ipv4certificate, ipv6certificate) }) })