feat(requests): add dynamic CA certificate reloading

Add DynamicCALoader and DynamicTLSTransport for hot-reloadable CA
certificates without requiring application restart.

- DynamicCALoader: watches CA files via fsnotify and reloads on change,
  supports Kubernetes ConfigMap/Secret symlink replacement pattern
- DynamicTLSTransport: wraps http.Transport with dynamic CA verification
  using VerifyPeerCertificate callback
- Add atomic transport proxy to http.go for runtime transport swapping

Signed-off-by: Niki Dokovski <nickytd@gmail.com>
This commit is contained in:
Niki Dokovski 2026-02-27 15:23:33 +01:00
parent b5c8df7988
commit c77efe2300
No known key found for this signature in database
4 changed files with 411 additions and 1 deletions

View File

@ -0,0 +1,69 @@
package requests
import (
"crypto/x509"
"sync/atomic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/watcher"
)
// DynamicCALoader provides hot-reloadable CA certificates.
// It caches the certificate pool and reloads from disk when triggered
// by file system events via the watcher.
type DynamicCALoader struct {
caFiles []string
useSystemTrustStore bool
certPool atomic.Pointer[x509.CertPool]
}
// NewDynamicCALoader creates a new DynamicCALoader that loads CA certificates
// from the specified files. The certificates are reloaded automatically when
// files change on disk.
func NewDynamicCALoader(caFiles []string, useSystemTrustStore bool) (*DynamicCALoader, error) {
d := &DynamicCALoader{
caFiles: caFiles,
useSystemTrustStore: useSystemTrustStore,
}
if err := d.reload(); err != nil {
return nil, err
}
return d, nil
}
// GetCertPool returns the current CA certificate pool.
func (d *DynamicCALoader) GetCertPool() (*x509.CertPool, error) {
return d.certPool.Load(), nil
}
// reload loads CA certificates from disk and updates the cached pool.
func (d *DynamicCALoader) reload() error {
pool, err := util.GetCertPool(d.caFiles, d.useSystemTrustStore)
if err != nil {
return err
}
d.certPool.Store(pool)
logger.Printf("CA certificates reloaded from %v", d.caFiles)
return nil
}
// ForceReload forces immediate reload of CA certificates.
// This is called by file watcher when CA files change on disk.
func (d *DynamicCALoader) ForceReload() {
if err := d.reload(); err != nil {
logger.Errorf("CA reload failed: %v", err)
}
}
// StartWatching sets up file watchers for all CA files.
// When any CA file changes, the certificates are automatically reloaded.
// This supports Kubernetes ConfigMap/Secret mounts which use symlink replacement.
func (d *DynamicCALoader) StartWatching(done <-chan bool) error {
for _, caFile := range d.caFiles {
if err := watcher.WatchFileForUpdates(caFile, done, d.ForceReload); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,225 @@
package requests
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// generateTestCert creates a self-signed certificate for testing
func generateTestCert(t *testing.T, commonName string) (certPEM []byte) {
t.Helper()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: commonName,
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
return certPEM
}
// createTempCAFile creates a temporary CA file for testing
func createTempCAFile(t *testing.T, dir string, certPEM []byte) string {
t.Helper()
file, err := os.CreateTemp(dir, "ca-*.pem")
require.NoError(t, err)
_, err = file.Write(certPEM)
require.NoError(t, err)
require.NoError(t, file.Close())
return file.Name()
}
func TestNewDynamicCALoader(t *testing.T) {
tempDir := t.TempDir()
certPEM := generateTestCert(t, "Test CA")
caFile := createTempCAFile(t, tempDir, certPEM)
loader, err := NewDynamicCALoader([]string{caFile}, false)
require.NoError(t, err)
require.NotNil(t, loader)
pool, err := loader.GetCertPool()
require.NoError(t, err)
require.NotNil(t, pool)
}
func TestNewDynamicCALoader_InvalidFile(t *testing.T) {
loader, err := NewDynamicCALoader([]string{"/nonexistent/ca.pem"}, false)
assert.Error(t, err)
assert.Nil(t, loader)
}
func TestNewDynamicCALoader_EmptyFiles(t *testing.T) {
loader, err := NewDynamicCALoader([]string{}, false)
assert.Error(t, err)
assert.Nil(t, loader)
}
func TestDynamicCALoader_GetCertPool_ReturnsCached(t *testing.T) {
tempDir := t.TempDir()
certPEM := generateTestCert(t, "Test CA")
caFile := createTempCAFile(t, tempDir, certPEM)
loader, err := NewDynamicCALoader([]string{caFile}, false)
require.NoError(t, err)
// Multiple calls should return the same cached pool
pool1, err := loader.GetCertPool()
require.NoError(t, err)
pool2, err := loader.GetCertPool()
require.NoError(t, err)
// Both should be the same pointer (cached)
assert.Same(t, pool1, pool2)
}
func TestDynamicCALoader_ForceReload(t *testing.T) {
tempDir := t.TempDir()
certPEM := generateTestCert(t, "Test CA")
caFile := createTempCAFile(t, tempDir, certPEM)
loader, err := NewDynamicCALoader([]string{caFile}, false)
require.NoError(t, err)
pool1, err := loader.GetCertPool()
require.NoError(t, err)
// Update the file with new cert
newCertPEM := generateTestCert(t, "New Test CA")
err = os.WriteFile(caFile, newCertPEM, 0600)
require.NoError(t, err)
// Force reload
loader.ForceReload()
pool2, err := loader.GetCertPool()
require.NoError(t, err)
// Pools should be different after reload
assert.NotSame(t, pool1, pool2)
}
func TestDynamicCALoader_ForceReload_ErrorKeepsOldCerts(t *testing.T) {
tempDir := t.TempDir()
certPEM := generateTestCert(t, "Test CA")
caFile := createTempCAFile(t, tempDir, certPEM)
loader, err := NewDynamicCALoader([]string{caFile}, false)
require.NoError(t, err)
pool1, err := loader.GetCertPool()
require.NoError(t, err)
// Delete the file to cause reload error
err = os.Remove(caFile)
require.NoError(t, err)
// ForceReload should fail but keep old certs
loader.ForceReload()
// GetCertPool should still return the old pool
pool2, err := loader.GetCertPool()
require.NoError(t, err)
assert.Same(t, pool1, pool2)
}
func TestDynamicCALoader_MultipleCertFiles(t *testing.T) {
tempDir := t.TempDir()
cert1PEM := generateTestCert(t, "Test CA 1")
cert2PEM := generateTestCert(t, "Test CA 2")
caFile1 := createTempCAFile(t, tempDir, cert1PEM)
caFile2 := createTempCAFile(t, tempDir, cert2PEM)
loader, err := NewDynamicCALoader([]string{caFile1, caFile2}, false)
require.NoError(t, err)
pool, err := loader.GetCertPool()
require.NoError(t, err)
require.NotNil(t, pool)
}
func TestDynamicTLSTransport_Creation(t *testing.T) {
tempDir := t.TempDir()
certPEM := generateTestCert(t, "Test CA")
caFile := createTempCAFile(t, tempDir, certPEM)
loader, err := NewDynamicCALoader([]string{caFile}, false)
require.NoError(t, err)
transport := NewDynamicTLSTransport(loader)
require.NotNil(t, transport)
require.NotNil(t, transport.base)
require.NotNil(t, transport.base.TLSClientConfig)
}
func TestDynamicCALoader_StartWatching(t *testing.T) {
tempDir := t.TempDir()
certPEM := generateTestCert(t, "Test CA")
caFile := filepath.Join(tempDir, "ca.pem")
err := os.WriteFile(caFile, certPEM, 0600)
require.NoError(t, err)
loader, err := NewDynamicCALoader([]string{caFile}, false)
require.NoError(t, err)
done := make(chan bool)
defer close(done)
err = loader.StartWatching(done)
require.NoError(t, err)
// Give watcher time to start
time.Sleep(100 * time.Millisecond)
pool1, err := loader.GetCertPool()
require.NoError(t, err)
// Update the file
newCertPEM := generateTestCert(t, "New Test CA")
err = os.WriteFile(caFile, newCertPEM, 0600)
require.NoError(t, err)
// Give watcher time to detect and reload
time.Sleep(200 * time.Millisecond)
pool2, err := loader.GetCertPool()
require.NoError(t, err)
// After file change, the pool should be reloaded
assert.NotSame(t, pool1, pool2)
}

View File

@ -0,0 +1,88 @@
package requests
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
)
// DynamicTLSTransport wraps http.Transport with dynamic CA verification.
// It uses VerifyPeerCertificate callback to perform certificate verification
// with a dynamically reloadable CA pool, enabling hot-reload of CA certificates
// without restarting the application.
type DynamicTLSTransport struct {
base *http.Transport
caLoader *DynamicCALoader
}
// NewDynamicTLSTransport creates a new transport that uses the provided
// DynamicCALoader for certificate verification. The CA certificates can
// be reloaded at runtime without affecting in-flight requests.
func NewDynamicTLSTransport(caLoader *DynamicCALoader) *DynamicTLSTransport {
base := http.DefaultTransport.(*http.Transport).Clone()
dt := &DynamicTLSTransport{
base: base,
caLoader: caLoader,
}
base.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
// We set InsecureSkipVerify to true because we perform our own
// verification in VerifyPeerCertificate. This allows us to use
// a dynamically updated CA pool for verification.
InsecureSkipVerify: true, // #nosec G402 -- verification done in VerifyPeerCertificate
VerifyPeerCertificate: dt.verifyPeerCertificate,
}
return dt
}
// RoundTrip implements http.RoundTripper
func (dt *DynamicTLSTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return dt.base.RoundTrip(req)
}
// verifyPeerCertificate is called during TLS handshake to verify the server's
// certificate chain using the current CA pool from the DynamicCALoader.
// This enables hot-reload of CA certificates - each new connection will use
// the latest CA certificates.
func (dt *DynamicTLSTransport) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return fmt.Errorf("no certificates provided by server")
}
pool, err := dt.caLoader.GetCertPool()
if err != nil {
return fmt.Errorf("failed to get CA pool: %w", err)
}
// Parse all certificates in the chain
certs := make([]*x509.Certificate, len(rawCerts))
for i, raw := range rawCerts {
cert, err := x509.ParseCertificate(raw)
if err != nil {
return fmt.Errorf("failed to parse certificate: %w", err)
}
certs[i] = cert
}
// Build verification options with the dynamic CA pool
opts := x509.VerifyOptions{
Roots: pool,
Intermediates: x509.NewCertPool(),
}
// Add intermediate certificates (all except the leaf)
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
// Verify the leaf certificate
if _, err := certs[0].Verify(opts); err != nil {
return fmt.Errorf("certificate verification failed: %w", err)
}
return nil
}

View File

@ -2,6 +2,7 @@ package requests
import (
"net/http"
"sync/atomic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/version"
)
@ -17,13 +18,40 @@ func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error
return t.next.RoundTrip(r)
}
// defaultTransport holds a pointer to the current transport (can be swapped atomically)
var defaultTransport atomic.Pointer[http.RoundTripper]
// transportProxy implements http.RoundTripper and delegates to the atomic pointer.
// This allows the underlying transport to be swapped at runtime for CA reload.
type transportProxy struct{}
func (t *transportProxy) RoundTrip(req *http.Request) (*http.Response, error) {
transport := defaultTransport.Load()
if transport == nil {
return http.DefaultTransport.RoundTrip(req)
}
return (*transport).RoundTrip(req)
}
// DefaultHTTPClient is the shared HTTP client used for provider requests.
// It uses a transport proxy that supports runtime transport swapping for CA reload.
var DefaultHTTPClient = &http.Client{Transport: &userAgentTransport{
next: DefaultTransport,
next: &transportProxy{},
userAgent: "oauth2-proxy/" + version.VERSION,
}}
// DefaultTransport is kept for backward compatibility.
// New code should use SetDefaultTransport to update the transport.
var DefaultTransport = http.DefaultTransport
// SetDefaultTransport atomically sets the default transport used by DefaultHTTPClient.
// This is used to enable dynamic CA certificate reloading.
func SetDefaultTransport(rt http.RoundTripper) {
defaultTransport.Store(&rt)
// Also update the legacy variable for backward compatibility with existing code
DefaultTransport = rt
}
func setDefaultUserAgent(header http.Header, userAgent string) {
if header != nil && len(header.Values("User-Agent")) == 0 {
header.Set("User-Agent", userAgent)