From c77efe2300b2e1d747fda14990c2bd3591dcc659 Mon Sep 17 00:00:00 2001 From: Niki Dokovski Date: Fri, 27 Feb 2026 15:23:33 +0100 Subject: [PATCH] feat(requests): add dynamic CA certificate reloading Add DynamicCALoader and DynamicTLSTransport for hot-reloadable CA certificates without requiring application restart. - DynamicCALoader: watches CA files via fsnotify and reloads on change, supports Kubernetes ConfigMap/Secret symlink replacement pattern - DynamicTLSTransport: wraps http.Transport with dynamic CA verification using VerifyPeerCertificate callback - Add atomic transport proxy to http.go for runtime transport swapping Signed-off-by: Niki Dokovski --- pkg/requests/dynamic_ca.go | 69 +++++++++ pkg/requests/dynamic_ca_test.go | 225 ++++++++++++++++++++++++++++++ pkg/requests/dynamic_transport.go | 88 ++++++++++++ pkg/requests/http.go | 30 +++- 4 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 pkg/requests/dynamic_ca.go create mode 100644 pkg/requests/dynamic_ca_test.go create mode 100644 pkg/requests/dynamic_transport.go diff --git a/pkg/requests/dynamic_ca.go b/pkg/requests/dynamic_ca.go new file mode 100644 index 00000000..e5ca8903 --- /dev/null +++ b/pkg/requests/dynamic_ca.go @@ -0,0 +1,69 @@ +package requests + +import ( + "crypto/x509" + "sync/atomic" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/watcher" +) + +// DynamicCALoader provides hot-reloadable CA certificates. +// It caches the certificate pool and reloads from disk when triggered +// by file system events via the watcher. +type DynamicCALoader struct { + caFiles []string + useSystemTrustStore bool + certPool atomic.Pointer[x509.CertPool] +} + +// NewDynamicCALoader creates a new DynamicCALoader that loads CA certificates +// from the specified files. The certificates are reloaded automatically when +// files change on disk. +func NewDynamicCALoader(caFiles []string, useSystemTrustStore bool) (*DynamicCALoader, error) { + d := &DynamicCALoader{ + caFiles: caFiles, + useSystemTrustStore: useSystemTrustStore, + } + if err := d.reload(); err != nil { + return nil, err + } + return d, nil +} + +// GetCertPool returns the current CA certificate pool. +func (d *DynamicCALoader) GetCertPool() (*x509.CertPool, error) { + return d.certPool.Load(), nil +} + +// reload loads CA certificates from disk and updates the cached pool. +func (d *DynamicCALoader) reload() error { + pool, err := util.GetCertPool(d.caFiles, d.useSystemTrustStore) + if err != nil { + return err + } + d.certPool.Store(pool) + logger.Printf("CA certificates reloaded from %v", d.caFiles) + return nil +} + +// ForceReload forces immediate reload of CA certificates. +// This is called by file watcher when CA files change on disk. +func (d *DynamicCALoader) ForceReload() { + if err := d.reload(); err != nil { + logger.Errorf("CA reload failed: %v", err) + } +} + +// StartWatching sets up file watchers for all CA files. +// When any CA file changes, the certificates are automatically reloaded. +// This supports Kubernetes ConfigMap/Secret mounts which use symlink replacement. +func (d *DynamicCALoader) StartWatching(done <-chan bool) error { + for _, caFile := range d.caFiles { + if err := watcher.WatchFileForUpdates(caFile, done, d.ForceReload); err != nil { + return err + } + } + return nil +} diff --git a/pkg/requests/dynamic_ca_test.go b/pkg/requests/dynamic_ca_test.go new file mode 100644 index 00000000..12b57a8d --- /dev/null +++ b/pkg/requests/dynamic_ca_test.go @@ -0,0 +1,225 @@ +package requests + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// generateTestCert creates a self-signed certificate for testing +func generateTestCert(t *testing.T, commonName string) (certPEM []byte) { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: commonName, + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + require.NoError(t, err) + + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + return certPEM +} + +// createTempCAFile creates a temporary CA file for testing +func createTempCAFile(t *testing.T, dir string, certPEM []byte) string { + t.Helper() + + file, err := os.CreateTemp(dir, "ca-*.pem") + require.NoError(t, err) + + _, err = file.Write(certPEM) + require.NoError(t, err) + require.NoError(t, file.Close()) + + return file.Name() +} + +func TestNewDynamicCALoader(t *testing.T) { + tempDir := t.TempDir() + certPEM := generateTestCert(t, "Test CA") + caFile := createTempCAFile(t, tempDir, certPEM) + + loader, err := NewDynamicCALoader([]string{caFile}, false) + require.NoError(t, err) + require.NotNil(t, loader) + + pool, err := loader.GetCertPool() + require.NoError(t, err) + require.NotNil(t, pool) +} + +func TestNewDynamicCALoader_InvalidFile(t *testing.T) { + loader, err := NewDynamicCALoader([]string{"/nonexistent/ca.pem"}, false) + assert.Error(t, err) + assert.Nil(t, loader) +} + +func TestNewDynamicCALoader_EmptyFiles(t *testing.T) { + loader, err := NewDynamicCALoader([]string{}, false) + assert.Error(t, err) + assert.Nil(t, loader) +} + +func TestDynamicCALoader_GetCertPool_ReturnsCached(t *testing.T) { + tempDir := t.TempDir() + certPEM := generateTestCert(t, "Test CA") + caFile := createTempCAFile(t, tempDir, certPEM) + + loader, err := NewDynamicCALoader([]string{caFile}, false) + require.NoError(t, err) + + // Multiple calls should return the same cached pool + pool1, err := loader.GetCertPool() + require.NoError(t, err) + + pool2, err := loader.GetCertPool() + require.NoError(t, err) + + // Both should be the same pointer (cached) + assert.Same(t, pool1, pool2) +} + +func TestDynamicCALoader_ForceReload(t *testing.T) { + tempDir := t.TempDir() + certPEM := generateTestCert(t, "Test CA") + caFile := createTempCAFile(t, tempDir, certPEM) + + loader, err := NewDynamicCALoader([]string{caFile}, false) + require.NoError(t, err) + + pool1, err := loader.GetCertPool() + require.NoError(t, err) + + // Update the file with new cert + newCertPEM := generateTestCert(t, "New Test CA") + err = os.WriteFile(caFile, newCertPEM, 0600) + require.NoError(t, err) + + // Force reload + loader.ForceReload() + + pool2, err := loader.GetCertPool() + require.NoError(t, err) + + // Pools should be different after reload + assert.NotSame(t, pool1, pool2) +} + +func TestDynamicCALoader_ForceReload_ErrorKeepsOldCerts(t *testing.T) { + tempDir := t.TempDir() + certPEM := generateTestCert(t, "Test CA") + caFile := createTempCAFile(t, tempDir, certPEM) + + loader, err := NewDynamicCALoader([]string{caFile}, false) + require.NoError(t, err) + + pool1, err := loader.GetCertPool() + require.NoError(t, err) + + // Delete the file to cause reload error + err = os.Remove(caFile) + require.NoError(t, err) + + // ForceReload should fail but keep old certs + loader.ForceReload() + + // GetCertPool should still return the old pool + pool2, err := loader.GetCertPool() + require.NoError(t, err) + assert.Same(t, pool1, pool2) +} + +func TestDynamicCALoader_MultipleCertFiles(t *testing.T) { + tempDir := t.TempDir() + + cert1PEM := generateTestCert(t, "Test CA 1") + cert2PEM := generateTestCert(t, "Test CA 2") + + caFile1 := createTempCAFile(t, tempDir, cert1PEM) + caFile2 := createTempCAFile(t, tempDir, cert2PEM) + + loader, err := NewDynamicCALoader([]string{caFile1, caFile2}, false) + require.NoError(t, err) + + pool, err := loader.GetCertPool() + require.NoError(t, err) + require.NotNil(t, pool) +} + +func TestDynamicTLSTransport_Creation(t *testing.T) { + tempDir := t.TempDir() + certPEM := generateTestCert(t, "Test CA") + caFile := createTempCAFile(t, tempDir, certPEM) + + loader, err := NewDynamicCALoader([]string{caFile}, false) + require.NoError(t, err) + + transport := NewDynamicTLSTransport(loader) + require.NotNil(t, transport) + require.NotNil(t, transport.base) + require.NotNil(t, transport.base.TLSClientConfig) +} + +func TestDynamicCALoader_StartWatching(t *testing.T) { + tempDir := t.TempDir() + certPEM := generateTestCert(t, "Test CA") + caFile := filepath.Join(tempDir, "ca.pem") + err := os.WriteFile(caFile, certPEM, 0600) + require.NoError(t, err) + + loader, err := NewDynamicCALoader([]string{caFile}, false) + require.NoError(t, err) + + done := make(chan bool) + defer close(done) + + err = loader.StartWatching(done) + require.NoError(t, err) + + // Give watcher time to start + time.Sleep(100 * time.Millisecond) + + pool1, err := loader.GetCertPool() + require.NoError(t, err) + + // Update the file + newCertPEM := generateTestCert(t, "New Test CA") + err = os.WriteFile(caFile, newCertPEM, 0600) + require.NoError(t, err) + + // Give watcher time to detect and reload + time.Sleep(200 * time.Millisecond) + + pool2, err := loader.GetCertPool() + require.NoError(t, err) + + // After file change, the pool should be reloaded + assert.NotSame(t, pool1, pool2) +} diff --git a/pkg/requests/dynamic_transport.go b/pkg/requests/dynamic_transport.go new file mode 100644 index 00000000..0c7db611 --- /dev/null +++ b/pkg/requests/dynamic_transport.go @@ -0,0 +1,88 @@ +package requests + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" +) + +// DynamicTLSTransport wraps http.Transport with dynamic CA verification. +// It uses VerifyPeerCertificate callback to perform certificate verification +// with a dynamically reloadable CA pool, enabling hot-reload of CA certificates +// without restarting the application. +type DynamicTLSTransport struct { + base *http.Transport + caLoader *DynamicCALoader +} + +// NewDynamicTLSTransport creates a new transport that uses the provided +// DynamicCALoader for certificate verification. The CA certificates can +// be reloaded at runtime without affecting in-flight requests. +func NewDynamicTLSTransport(caLoader *DynamicCALoader) *DynamicTLSTransport { + base := http.DefaultTransport.(*http.Transport).Clone() + + dt := &DynamicTLSTransport{ + base: base, + caLoader: caLoader, + } + + base.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + // We set InsecureSkipVerify to true because we perform our own + // verification in VerifyPeerCertificate. This allows us to use + // a dynamically updated CA pool for verification. + InsecureSkipVerify: true, // #nosec G402 -- verification done in VerifyPeerCertificate + VerifyPeerCertificate: dt.verifyPeerCertificate, + } + + return dt +} + +// RoundTrip implements http.RoundTripper +func (dt *DynamicTLSTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return dt.base.RoundTrip(req) +} + +// verifyPeerCertificate is called during TLS handshake to verify the server's +// certificate chain using the current CA pool from the DynamicCALoader. +// This enables hot-reload of CA certificates - each new connection will use +// the latest CA certificates. +func (dt *DynamicTLSTransport) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return fmt.Errorf("no certificates provided by server") + } + + pool, err := dt.caLoader.GetCertPool() + if err != nil { + return fmt.Errorf("failed to get CA pool: %w", err) + } + + // Parse all certificates in the chain + certs := make([]*x509.Certificate, len(rawCerts)) + for i, raw := range rawCerts { + cert, err := x509.ParseCertificate(raw) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + certs[i] = cert + } + + // Build verification options with the dynamic CA pool + opts := x509.VerifyOptions{ + Roots: pool, + Intermediates: x509.NewCertPool(), + } + + // Add intermediate certificates (all except the leaf) + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + // Verify the leaf certificate + if _, err := certs[0].Verify(opts); err != nil { + return fmt.Errorf("certificate verification failed: %w", err) + } + + return nil +} diff --git a/pkg/requests/http.go b/pkg/requests/http.go index c0035e0a..ee05c23c 100644 --- a/pkg/requests/http.go +++ b/pkg/requests/http.go @@ -2,6 +2,7 @@ package requests import ( "net/http" + "sync/atomic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/version" ) @@ -17,13 +18,40 @@ func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error return t.next.RoundTrip(r) } +// defaultTransport holds a pointer to the current transport (can be swapped atomically) +var defaultTransport atomic.Pointer[http.RoundTripper] + +// transportProxy implements http.RoundTripper and delegates to the atomic pointer. +// This allows the underlying transport to be swapped at runtime for CA reload. +type transportProxy struct{} + +func (t *transportProxy) RoundTrip(req *http.Request) (*http.Response, error) { + transport := defaultTransport.Load() + if transport == nil { + return http.DefaultTransport.RoundTrip(req) + } + return (*transport).RoundTrip(req) +} + +// DefaultHTTPClient is the shared HTTP client used for provider requests. +// It uses a transport proxy that supports runtime transport swapping for CA reload. var DefaultHTTPClient = &http.Client{Transport: &userAgentTransport{ - next: DefaultTransport, + next: &transportProxy{}, userAgent: "oauth2-proxy/" + version.VERSION, }} +// DefaultTransport is kept for backward compatibility. +// New code should use SetDefaultTransport to update the transport. var DefaultTransport = http.DefaultTransport +// SetDefaultTransport atomically sets the default transport used by DefaultHTTPClient. +// This is used to enable dynamic CA certificate reloading. +func SetDefaultTransport(rt http.RoundTripper) { + defaultTransport.Store(&rt) + // Also update the legacy variable for backward compatibility with existing code + DefaultTransport = rt +} + func setDefaultUserAgent(header http.Header, userAgent string) { if header != nil && len(header.Values("User-Agent")) == 0 { header.Set("User-Agent", userAgent)