feat(requests): add dynamic CA certificate reloading
Add DynamicCALoader and DynamicTLSTransport for hot-reloadable CA certificates without requiring application restart. - DynamicCALoader: watches CA files via fsnotify and reloads on change, supports Kubernetes ConfigMap/Secret symlink replacement pattern - DynamicTLSTransport: wraps http.Transport with dynamic CA verification using VerifyPeerCertificate callback - Add atomic transport proxy to http.go for runtime transport swapping Signed-off-by: Niki Dokovski <nickytd@gmail.com>
This commit is contained in:
parent
b5c8df7988
commit
c77efe2300
|
|
@ -0,0 +1,69 @@
|
|||
package requests
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/watcher"
|
||||
)
|
||||
|
||||
// DynamicCALoader provides hot-reloadable CA certificates.
|
||||
// It caches the certificate pool and reloads from disk when triggered
|
||||
// by file system events via the watcher.
|
||||
type DynamicCALoader struct {
|
||||
caFiles []string
|
||||
useSystemTrustStore bool
|
||||
certPool atomic.Pointer[x509.CertPool]
|
||||
}
|
||||
|
||||
// NewDynamicCALoader creates a new DynamicCALoader that loads CA certificates
|
||||
// from the specified files. The certificates are reloaded automatically when
|
||||
// files change on disk.
|
||||
func NewDynamicCALoader(caFiles []string, useSystemTrustStore bool) (*DynamicCALoader, error) {
|
||||
d := &DynamicCALoader{
|
||||
caFiles: caFiles,
|
||||
useSystemTrustStore: useSystemTrustStore,
|
||||
}
|
||||
if err := d.reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// GetCertPool returns the current CA certificate pool.
|
||||
func (d *DynamicCALoader) GetCertPool() (*x509.CertPool, error) {
|
||||
return d.certPool.Load(), nil
|
||||
}
|
||||
|
||||
// reload loads CA certificates from disk and updates the cached pool.
|
||||
func (d *DynamicCALoader) reload() error {
|
||||
pool, err := util.GetCertPool(d.caFiles, d.useSystemTrustStore)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.certPool.Store(pool)
|
||||
logger.Printf("CA certificates reloaded from %v", d.caFiles)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceReload forces immediate reload of CA certificates.
|
||||
// This is called by file watcher when CA files change on disk.
|
||||
func (d *DynamicCALoader) ForceReload() {
|
||||
if err := d.reload(); err != nil {
|
||||
logger.Errorf("CA reload failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// StartWatching sets up file watchers for all CA files.
|
||||
// When any CA file changes, the certificates are automatically reloaded.
|
||||
// This supports Kubernetes ConfigMap/Secret mounts which use symlink replacement.
|
||||
func (d *DynamicCALoader) StartWatching(done <-chan bool) error {
|
||||
for _, caFile := range d.caFiles {
|
||||
if err := watcher.WatchFileForUpdates(caFile, done, d.ForceReload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,225 @@
|
|||
package requests
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// generateTestCert creates a self-signed certificate for testing
|
||||
func generateTestCert(t *testing.T, commonName string) (certPEM []byte) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
require.NoError(t, err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: commonName,
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
return certPEM
|
||||
}
|
||||
|
||||
// createTempCAFile creates a temporary CA file for testing
|
||||
func createTempCAFile(t *testing.T, dir string, certPEM []byte) string {
|
||||
t.Helper()
|
||||
|
||||
file, err := os.CreateTemp(dir, "ca-*.pem")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = file.Write(certPEM)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, file.Close())
|
||||
|
||||
return file.Name()
|
||||
}
|
||||
|
||||
func TestNewDynamicCALoader(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certPEM := generateTestCert(t, "Test CA")
|
||||
caFile := createTempCAFile(t, tempDir, certPEM)
|
||||
|
||||
loader, err := NewDynamicCALoader([]string{caFile}, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loader)
|
||||
|
||||
pool, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pool)
|
||||
}
|
||||
|
||||
func TestNewDynamicCALoader_InvalidFile(t *testing.T) {
|
||||
loader, err := NewDynamicCALoader([]string{"/nonexistent/ca.pem"}, false)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, loader)
|
||||
}
|
||||
|
||||
func TestNewDynamicCALoader_EmptyFiles(t *testing.T) {
|
||||
loader, err := NewDynamicCALoader([]string{}, false)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, loader)
|
||||
}
|
||||
|
||||
func TestDynamicCALoader_GetCertPool_ReturnsCached(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certPEM := generateTestCert(t, "Test CA")
|
||||
caFile := createTempCAFile(t, tempDir, certPEM)
|
||||
|
||||
loader, err := NewDynamicCALoader([]string{caFile}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Multiple calls should return the same cached pool
|
||||
pool1, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
pool2, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both should be the same pointer (cached)
|
||||
assert.Same(t, pool1, pool2)
|
||||
}
|
||||
|
||||
func TestDynamicCALoader_ForceReload(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certPEM := generateTestCert(t, "Test CA")
|
||||
caFile := createTempCAFile(t, tempDir, certPEM)
|
||||
|
||||
loader, err := NewDynamicCALoader([]string{caFile}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
pool1, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the file with new cert
|
||||
newCertPEM := generateTestCert(t, "New Test CA")
|
||||
err = os.WriteFile(caFile, newCertPEM, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Force reload
|
||||
loader.ForceReload()
|
||||
|
||||
pool2, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Pools should be different after reload
|
||||
assert.NotSame(t, pool1, pool2)
|
||||
}
|
||||
|
||||
func TestDynamicCALoader_ForceReload_ErrorKeepsOldCerts(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certPEM := generateTestCert(t, "Test CA")
|
||||
caFile := createTempCAFile(t, tempDir, certPEM)
|
||||
|
||||
loader, err := NewDynamicCALoader([]string{caFile}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
pool1, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the file to cause reload error
|
||||
err = os.Remove(caFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ForceReload should fail but keep old certs
|
||||
loader.ForceReload()
|
||||
|
||||
// GetCertPool should still return the old pool
|
||||
pool2, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, pool1, pool2)
|
||||
}
|
||||
|
||||
func TestDynamicCALoader_MultipleCertFiles(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cert1PEM := generateTestCert(t, "Test CA 1")
|
||||
cert2PEM := generateTestCert(t, "Test CA 2")
|
||||
|
||||
caFile1 := createTempCAFile(t, tempDir, cert1PEM)
|
||||
caFile2 := createTempCAFile(t, tempDir, cert2PEM)
|
||||
|
||||
loader, err := NewDynamicCALoader([]string{caFile1, caFile2}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
pool, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pool)
|
||||
}
|
||||
|
||||
func TestDynamicTLSTransport_Creation(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certPEM := generateTestCert(t, "Test CA")
|
||||
caFile := createTempCAFile(t, tempDir, certPEM)
|
||||
|
||||
loader, err := NewDynamicCALoader([]string{caFile}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport := NewDynamicTLSTransport(loader)
|
||||
require.NotNil(t, transport)
|
||||
require.NotNil(t, transport.base)
|
||||
require.NotNil(t, transport.base.TLSClientConfig)
|
||||
}
|
||||
|
||||
func TestDynamicCALoader_StartWatching(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certPEM := generateTestCert(t, "Test CA")
|
||||
caFile := filepath.Join(tempDir, "ca.pem")
|
||||
err := os.WriteFile(caFile, certPEM, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
loader, err := NewDynamicCALoader([]string{caFile}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan bool)
|
||||
defer close(done)
|
||||
|
||||
err = loader.StartWatching(done)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give watcher time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
pool1, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the file
|
||||
newCertPEM := generateTestCert(t, "New Test CA")
|
||||
err = os.WriteFile(caFile, newCertPEM, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Give watcher time to detect and reload
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
pool2, err := loader.GetCertPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
// After file change, the pool should be reloaded
|
||||
assert.NotSame(t, pool1, pool2)
|
||||
}
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
package requests
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// DynamicTLSTransport wraps http.Transport with dynamic CA verification.
|
||||
// It uses VerifyPeerCertificate callback to perform certificate verification
|
||||
// with a dynamically reloadable CA pool, enabling hot-reload of CA certificates
|
||||
// without restarting the application.
|
||||
type DynamicTLSTransport struct {
|
||||
base *http.Transport
|
||||
caLoader *DynamicCALoader
|
||||
}
|
||||
|
||||
// NewDynamicTLSTransport creates a new transport that uses the provided
|
||||
// DynamicCALoader for certificate verification. The CA certificates can
|
||||
// be reloaded at runtime without affecting in-flight requests.
|
||||
func NewDynamicTLSTransport(caLoader *DynamicCALoader) *DynamicTLSTransport {
|
||||
base := http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
dt := &DynamicTLSTransport{
|
||||
base: base,
|
||||
caLoader: caLoader,
|
||||
}
|
||||
|
||||
base.TLSClientConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
// We set InsecureSkipVerify to true because we perform our own
|
||||
// verification in VerifyPeerCertificate. This allows us to use
|
||||
// a dynamically updated CA pool for verification.
|
||||
InsecureSkipVerify: true, // #nosec G402 -- verification done in VerifyPeerCertificate
|
||||
VerifyPeerCertificate: dt.verifyPeerCertificate,
|
||||
}
|
||||
|
||||
return dt
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper
|
||||
func (dt *DynamicTLSTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return dt.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// verifyPeerCertificate is called during TLS handshake to verify the server's
|
||||
// certificate chain using the current CA pool from the DynamicCALoader.
|
||||
// This enables hot-reload of CA certificates - each new connection will use
|
||||
// the latest CA certificates.
|
||||
func (dt *DynamicTLSTransport) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
if len(rawCerts) == 0 {
|
||||
return fmt.Errorf("no certificates provided by server")
|
||||
}
|
||||
|
||||
pool, err := dt.caLoader.GetCertPool()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get CA pool: %w", err)
|
||||
}
|
||||
|
||||
// Parse all certificates in the chain
|
||||
certs := make([]*x509.Certificate, len(rawCerts))
|
||||
for i, raw := range rawCerts {
|
||||
cert, err := x509.ParseCertificate(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
certs[i] = cert
|
||||
}
|
||||
|
||||
// Build verification options with the dynamic CA pool
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: pool,
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
|
||||
// Add intermediate certificates (all except the leaf)
|
||||
for _, cert := range certs[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
|
||||
// Verify the leaf certificate
|
||||
if _, err := certs[0].Verify(opts); err != nil {
|
||||
return fmt.Errorf("certificate verification failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package requests
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/version"
|
||||
)
|
||||
|
|
@ -17,13 +18,40 @@ func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error
|
|||
return t.next.RoundTrip(r)
|
||||
}
|
||||
|
||||
// defaultTransport holds a pointer to the current transport (can be swapped atomically)
|
||||
var defaultTransport atomic.Pointer[http.RoundTripper]
|
||||
|
||||
// transportProxy implements http.RoundTripper and delegates to the atomic pointer.
|
||||
// This allows the underlying transport to be swapped at runtime for CA reload.
|
||||
type transportProxy struct{}
|
||||
|
||||
func (t *transportProxy) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
transport := defaultTransport.Load()
|
||||
if transport == nil {
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
return (*transport).RoundTrip(req)
|
||||
}
|
||||
|
||||
// DefaultHTTPClient is the shared HTTP client used for provider requests.
|
||||
// It uses a transport proxy that supports runtime transport swapping for CA reload.
|
||||
var DefaultHTTPClient = &http.Client{Transport: &userAgentTransport{
|
||||
next: DefaultTransport,
|
||||
next: &transportProxy{},
|
||||
userAgent: "oauth2-proxy/" + version.VERSION,
|
||||
}}
|
||||
|
||||
// DefaultTransport is kept for backward compatibility.
|
||||
// New code should use SetDefaultTransport to update the transport.
|
||||
var DefaultTransport = http.DefaultTransport
|
||||
|
||||
// SetDefaultTransport atomically sets the default transport used by DefaultHTTPClient.
|
||||
// This is used to enable dynamic CA certificate reloading.
|
||||
func SetDefaultTransport(rt http.RoundTripper) {
|
||||
defaultTransport.Store(&rt)
|
||||
// Also update the legacy variable for backward compatibility with existing code
|
||||
DefaultTransport = rt
|
||||
}
|
||||
|
||||
func setDefaultUserAgent(header http.Header, userAgent string) {
|
||||
if header != nil && len(header.Values("User-Agent")) == 0 {
|
||||
header.Set("User-Agent", userAgent)
|
||||
|
|
|
|||
Loading…
Reference in New Issue